{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56cc68c1-3d7e-4d47-9394-cf698b3a0a81",
   "metadata": {},
   "source": [
    "# Taking the pulse of reform: What can a multi-channel analysis of media and parliamentary speeches tell us about public acceptability?\n",
    "\n",
    "## Description\n",
    "This notebook contains all the Python code used in the paper: *\"Taking the Pulse of Reform: What Can a Multi-Channel Analysis of Media and Parliamentary Speeches Tell Us About Public Acceptability?\"*\n",
    "\n",
    "## Structure\n",
    "1. **Data Loading**  \n",
    "   - Load and preprocess parliamentary speeches (media articles are preprocessed in a dedicated script for time constraints).  \n",
    "2. **Word Embeddings for Dimension Classification**  \n",
    "   - Train a Word2Vec model.  \n",
    "   - Obtain document embeddings computed as the centroid of the relative word-vectors.  \n",
    "   - Measure the IDF-weighted similarity between document embeddings and the four dimensions defined by keywords.  \n",
    "3. **Slicing the Data**  \n",
    "   - Identify media articles and parliamentary speeches related to the four dimensions.  \n",
    "4. **Non-negative Matrix Factorisation (NMF)**  \n",
    "   - Explore the media debate within each dimension.  \n",
    "   - Visualize the dynamics of topic coverage over time using time series.  \n",
    "5. **Coupled Matrix Factorisation (CMF)**  \n",
    "   - Compare media articles and parliamentary speeches within the four dimensions. \n",
    "\n",
    "## Authors\n",
    "**Simone Maria Parazzoli** (ISI Foundation, Turin, Italy)  \n",
    "**Michele Tizzani** (ISI Foundation, Turin, Italy)  \n",
    "**Marco Quaggiotto** (ISI Foundation, Turin, Italy; Department of Design, Politecnico di Milano, Milano, Italy)  \n",
    "**Fabrice Murtin** (Centre on Well-Being, Inclusion, Sustainability and Equal Opportunity, OECD, Paris, France)  \n",
    "**Neil Martin** (Centre on Well-Being, Inclusion, Sustainability and Equal Opportunity, OECD, Paris, France)  \n",
    "**Nicolò Gozzi** (ISI Foundation, Turin, Italy)  \n",
    "**Lætitia Gauvin** (ISI Foundation, Turin, Italy; Institute for Research on Sustainable Development, Aubervilliers, France)\n",
    "\n",
    "## Contact\n",
    "simoneparazzoli@gmail.com\n",
    "\n",
    "## Last Updated\n",
    "26 December 2024"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ba52348",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb03439b-de0d-44cb-9718-07c1a57a1f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import gensim\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "tqdm.pandas()\n",
    "\n",
    "import newspaper\n",
    "import spacy\n",
    "nlp = spacy.load(\"fr_core_news_sm\")\n",
    "from spacy.lang.fr.stop_words import STOP_WORDS\n",
    "import string \n",
    "from collections import Counter\n",
    "from scipy import spatial\n",
    "\n",
    "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import NMF\n",
    "\n",
    "import matcouply.decomposition as decomposition\n",
    "from matcouply.coupled_matrices import CoupledMatrixFactorization\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "import seaborn as sns\n",
    "import plotly.graph_objs as go\n",
    "import datetime\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "911b579e-0b7c-4f9f-a47a-a00f798865c5",
   "metadata": {},
   "source": [
    "# Data prep"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7507236a-fc35-4137-941a-f5ee63fd623e",
   "metadata": {},
   "source": [
    "## Parliament"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "819e39ba-6028-469f-8375-6c119a85a692",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import data from session on retraite and speeches mentioning retraite\n",
    "df_title_retraite = pd.read_csv('data/parliament_title_reform.csv') # Speeches from sessions with title referring to the reform\n",
    "df_speech_retraite = pd.read_csv('data/parliament_speeches_reform.csv') # Speeches mentioning the reform from sessions not dedicated to it\n",
    "\n",
    "# Merge data from session on retraite and speeches mentioning retraite\n",
    "df_parl = pd.concat([df_title_retraite, df_speech_retraite])\n",
    "df_parl = df_parl.dropna(subset='speech').drop_duplicates(subset='speech')\n",
    "\n",
    "# Filter speeches from November 2022 onwards\n",
    "df_parl['dateSeance'] = pd.to_datetime(df_parl['dateSeance'], format='%Y%m%d%H%M%S%f')\n",
    "df_parl = df_parl[df_parl['dateSeance'] >= pd.to_datetime('2022-11-01')]\n",
    "df_parl.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35be60d1-fbfd-462f-8342-e1b0204eda08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data\n",
    "stopwords = set(STOP_WORDS)\n",
    "custom_sw_parl = {'réforme','retraite','retrait','oui','no','avis','mme','monsieur','parole','pouvoir','faire','devoir','bien',\n",
    "                  'falloir','prendre','aller','bon','vouloir','bien','favorable','sous','remettre'}\n",
    "stopwords.update(custom_sw_parl)\n",
    "\n",
    "def lemmatizer(text):\n",
    "    doc = nlp(text)\n",
    "    lemmas = [token.lemma_ for token in doc]\n",
    "    return lemmas\n",
    "\n",
    "def whitespace_remover(tokens):\n",
    "    filtered_tokens = [token for token in tokens if token.strip() != '']\n",
    "    filtered_tokens_newline = [token for token in filtered_tokens if token != '\\n\\n']\n",
    "    return filtered_tokens_newline\n",
    "    \n",
    "def sw_remover(tokens):\n",
    "    filtered_tokens = [token for token in tokens if token.lower() not in stopwords]\n",
    "    return filtered_tokens\n",
    "\n",
    "def punct_remover(tokens):\n",
    "    filtered_tokens = [token for token in tokens if not any(char in string.punctuation for char in token)] # I changed from all to any!!\n",
    "    return filtered_tokens\n",
    "\n",
    "def lowercaser(tokens):\n",
    "    filtered_tokens = [token.lower() for token in tokens]\n",
    "    return filtered_tokens\n",
    "\n",
    "def joiner(tokens):\n",
    "    joined_tokens = ' '.join(tokens)\n",
    "    return joined_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51510f6b-e0fb-476c-a806-3d402fe1af3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create new columns 'TextLemmatized' and 'TextProcessed'\n",
    "df_parl['TextLemmatized'] = df_parl['speech'].progress_apply(lemmatizer)\n",
    "df_parl['TextProcessed'] = df_parl['TextLemmatized'].progress_apply(whitespace_remover)\n",
    "df_parl['TextProcessed'] = df_parl['TextProcessed'].progress_apply(sw_remover)\n",
    "df_parl['TextProcessed'] = df_parl['TextProcessed'].progress_apply(punct_remover)\n",
    "df_parl['TextProcessed'] = df_parl['TextProcessed'].progress_apply(lowercaser)\n",
    "df_parl['TextProcessed'] = df_parl['TextProcessed'].progress_apply(joiner)\n",
    "\n",
    "df_parl['TextProcessed']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8ea0f1-f946-4069-8743-a175d6aee2b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the parliamentary dataset for analysis\n",
    "df_parl_text = df_parl[['dateSeance','TextProcessed','speech']]\n",
    "df_parl_text.rename(columns={\"dateSeance\": \"Date\", \"TextProcessed\": \"TextProcessed\", \"speech\":\"Text\"}, inplace = True)\n",
    "df_parl_text['Source'] = 'Parliament'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a91795-49c3-4969-9686-e6f2ae3c0378",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude irrelevant articles\n",
    "df_parl_text = df_parl_text[~df_parl_text['TextProcessed'].str.contains('orpea')]\n",
    "df_parl_text.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72724da2-083e-4906-9545-c0d3e64fdbd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_parl_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15cc66f3-398d-4dd7-9112-0859b0c50c73",
   "metadata": {},
   "source": [
    "## Media"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1ff170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import data\n",
    "df_selectedmedia = pd.read_csv('data/media_articles_reform.csv')\n",
    "df_selectedmedia['DATE'] = pd.to_datetime(df_selectedmedia['DATE'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7bdd605",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract text from articles' URL\n",
    "def extract_text(url):\n",
    "    try:\n",
    "        article = newspaper.Article(url)\n",
    "        article.download()\n",
    "        article.parse()\n",
    "        return article.text\n",
    "    except:\n",
    "        return \"\"\n",
    "    \n",
    "df_selectedmedia['Text'] = df_selectedmedia['DocumentIdentifier'].apply(lambda x: extract_text(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1749485",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create new columns 'TextLemmatized' and 'TextProcessed'\n",
    "df_selectedmedia['TextLemmatized'] = df_selectedmedia['Text'].progress_apply(lemmatizer)\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextLemmatized'].progress_apply(whitespace_remover)\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextProcessed'].progress_apply(sw_remover)\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextProcessed'].progress_apply(punct_remover)\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextProcessed'].progress_apply(lowercaser)\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextProcessed'].progress_apply(joiner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd6c4c1d-a23a-4b38-991a-6af428d916d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lowercase processed articles\n",
    "df_selectedmedia['TextProcessed'] = df_selectedmedia['TextProcessed'].str.lower()\n",
    "df_selectedmedia.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43814afb-260e-4a0a-88d1-a02f1ad76b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the media dataset for analysis\n",
    "df_selectedmedia_text = df_selectedmedia[['DATE','TextProcessed','Text']]\n",
    "df_selectedmedia_text.rename(columns={\"DATE\": \"Date\", \"TextProcessed\": \"TextProcessed\",'Text':'Text'}, inplace = True)\n",
    "df_selectedmedia_text['Source'] = 'Media'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d82af4-d221-4ebf-8e69-906bd9377b3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude irrelevant articles\n",
    "df_selectedmedia_text = df_selectedmedia_text[~df_selectedmedia_text['TextProcessed'].str.contains('orpea')]\n",
    "df_selectedmedia_text.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee5b5837-233c-4e9b-8f43-9c1f812f8f11",
   "metadata": {},
   "source": [
    "### Time series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c35e8748-f3f4-4b7d-839b-412e133c4fc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample to daily frequency, counting the number of entries per day\n",
    "media_count = df_selectedmedia_text.resample('D', on='Date').size().reset_index(name='Count')\n",
    "parl_count = df_parl_text.resample('D', on='Date').size().reset_index(name='Count')\n",
    "\n",
    "# Ensure all dates are included in the range\n",
    "all_dates = pd.date_range(start=min(media_count['Date'].min(), parl_count['Date'].min()), \n",
    "                          end=max(media_count['Date'].max(), parl_count['Date'].max()), freq='D')\n",
    "\n",
    "media_count = media_count.set_index('Date').reindex(all_dates, fill_value=0).rename_axis('Date').reset_index()\n",
    "parl_count = parl_count.set_index('Date').reindex(all_dates, fill_value=0).rename_axis('Date').reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9623dc22-4457-4fc5-9637-41b393771b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define important dates \n",
    "important_dates = {\n",
    "    datetime.datetime(2022, 12, 31): 'Macron\\nannounces\\nthe reform',\n",
    "    datetime.datetime(2023, 1, 30): 'Law is\\nintroduced\\nin AN',\n",
    "    datetime.datetime(2023, 3, 16): 'PM invokes\\nArt. 49.3',\n",
    "    datetime.datetime(2023, 4, 14): 'Consitutional\\nCouncil ratifies\\nthe law',\n",
    "}\n",
    "\n",
    "# Plotting\n",
    "plt.figure(figsize=(12, 4))\n",
    "ax = plt.gca()\n",
    "\n",
    "# Plot media article volume with rolling mean\n",
    "plt.plot(media_count['Date'], media_count['Count'].rolling(3).mean(), label='Media articles', color='#DC143C')\n",
    "plt.plot(parl_count['Date'], parl_count['Count'].rolling(3).mean(), label='Parliamentary speeches', color='#7B68EE')\n",
    "\n",
    "# Customize the plot\n",
    "plt.xlabel('Date')\n",
    "plt.ylabel('Volume')\n",
    "plt.legend()\n",
    "#plt.title('Volume of Media Articles Over Time')\n",
    "\n",
    "ax.margins(x=0)\n",
    "# Remove top and right spines\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=0.1)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7423783c-23e8-4ece-b8ed-36793278988f",
   "metadata": {},
   "source": [
    "## Merge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d04dc9-482d-472c-afcb-32c76479b52b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine parliamentary and media datasets for unified analysis\n",
    "df_merge = pd.concat([df_parl_text, df_selectedmedia_text])\n",
    "df_merge"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c7af1db-ff36-4d54-a3db-876f5b0215f3",
   "metadata": {},
   "source": [
    "# Similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f43f9c3-0b27-42d5-8815-5113e3a2e905",
   "metadata": {},
   "source": [
    "## Word2Vec and centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b551f94-bece-432b-8b1e-c293ea393254",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert text columns to lists for processing\n",
    "documents = df_merge['Text'].astype(str).to_list()\n",
    "\n",
    "documents_processed = df_merge['TextProcessed']\n",
    "documents_processed = documents_processed.astype(str).to_list()\n",
    "\n",
    "# Tokenize each document into individual words\n",
    "documents_tokenised = []\n",
    "for doc in documents_processed:\n",
    "  tokens = doc.split()\n",
    "  documents_tokenised.append(tokens)\n",
    "\n",
    "len(documents_tokenised)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "151f21d6-fa5c-4219-99ea-02fb80f53119",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training Word2Vec model from the merged documents\n",
    "model = gensim.models.Word2Vec(documents_tokenised, vector_size=100, window=5, min_count=1, workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7882f7cb-3ff9-4401-8cbb-2f97c8cd10bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the centroid of word embeddings for each document\n",
    "document_centroids_in = []\n",
    "\n",
    "for doc in documents_tokenised:\n",
    "\n",
    "  document_vecs = []\n",
    "  \n",
    "  for word in doc:\n",
    "    if word in model.wv:\n",
    "      word_vec = model.wv[word]\n",
    "      document_vecs.append(word_vec)\n",
    "      \n",
    "  if document_vecs: # I had to add this here because I obtained warning indicating there are some empty speech_vecs lists resulted in NaN when mean\n",
    "      document_centroid = np.mean(document_vecs, axis=0)\n",
    "      \n",
    "  document_centroids_in.append(document_centroid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99eaa8d8-859d-4f74-b711-c1d605fad881",
   "metadata": {},
   "source": [
    "## IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf4ed78a-473a-49d2-8e7b-5f9c3ae6362e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate the Inverse Document Frequency (IDF) for terms across a set of documents\n",
    "def calculate_idf(documents):\n",
    "    \n",
    "    N = len(documents)\n",
    "    tD = Counter()\n",
    "\n",
    "    # Count the number of documents that contain each term\n",
    "    for document in documents:\n",
    "        features = set(document.lower().split())\n",
    "        for f in features:\n",
    "            tD[f] += 1\n",
    "\n",
    "    IDF = {}\n",
    "    for term, term_frequency in tD.items():\n",
    "        # Calculate the IDF for the term\n",
    "        term_IDF = np.log(float(N) / term_frequency)\n",
    "        IDF[term] = term_IDF\n",
    "\n",
    "    return IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a780b32b-118a-485f-b9a9-bdd3ed8557fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the IDF values for each word in the vocabulary\n",
    "voc_idf = calculate_idf(documents_processed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e754c91b-393a-401f-9fd3-e1b481c7c3d4",
   "metadata": {},
   "source": [
    "## TF-IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b08ced-66be-4cb6-a8c5-e4da322678ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF for all words in documents documents\n",
    "def calculate_tfidf(documents):\n",
    "    IDF = calculate_idf(documents)\n",
    "    tfidf_documents = []\n",
    "\n",
    "    # Calculate TF-IDF for each document\n",
    "    for document in documents:\n",
    "        tfidf = {}\n",
    "        word_count = document.lower().split()\n",
    "        term_counts = Counter(word_count)  # Term frequencies in the document\n",
    "        total_terms = len(word_count)      # Total number of terms in the document\n",
    "\n",
    "        for term, count in term_counts.items():\n",
    "            TF = count / total_terms  # Calculate term frequency (TF)\n",
    "            IDF_value = IDF.get(term, 0)  # Retrieve IDF value for the term\n",
    "            tfidf[term] = TF * IDF_value  # Calculate TF-IDF\n",
    "\n",
    "        tfidf_documents.append(tfidf)\n",
    "\n",
    "    return tfidf_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1287f803-725e-47e1-bfa6-d3d1ef0f4fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the IDF values for each word in the vocabulary\n",
    "voc_tfidf = calculate_tfidf(documents_processed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "235e5fc2-20e9-4df8-a7eb-e1627b50b469",
   "metadata": {},
   "source": [
    "## Weighted similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e025b0fc-0c02-493c-b8b0-2c153068dd47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate the weighted similarity between document vectors and keyword vectors using IDF weighting\n",
    "def keyword_sim_idfweight(keywords, doc_vecs, voc_idf):\n",
    "    doc_scores = []\n",
    "\n",
    "    for doc_vec in doc_vecs:\n",
    "        keyword_sims = []\n",
    "        total_weight = []\n",
    "\n",
    "        for keyword in keywords:\n",
    "            if keyword in model.wv:\n",
    "                sim = 1 - spatial.distance.cosine(model.wv[keyword], doc_vec)\n",
    "                idf_weight = voc_idf[keyword]                \n",
    "                weighted_sim = sim * idf_weight\n",
    "                \n",
    "                keyword_sims.append(weighted_sim)\n",
    "                total_weight.append(idf_weight)\n",
    "\n",
    "        doc_score = sum(keyword_sims) / sum(total_weight)\n",
    "\n",
    "        doc_scores.append(doc_score)\n",
    "\n",
    "    return doc_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1519e29e-dbcd-4353-bfe8-1282eb17d9d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import keywords\n",
    "kws_econ_extended = ['compte public', 'finance', 'investissement', 'croissance', 'création emploi', 'compétitivité', 'soutenabilité', 'financement', 'pérennité', 'chômage', 'emploi', 'pib', 'prix', 'coût transition', 'incitation', 'espérance vie', 'départ anticipé', 'euro', 'contribution', 'productivité', 'dette', 'déficit', 'achat', 'démographique', 'attractivité', 'inflation', 'naissance', 'population', 'prime', 'travailleur', 'valeur', 'public', 'entreprise', 'banque', 'crise', 'cdc', 'industriel', 'territoire', 'fiscal', 'insee', 'revenu', 'cotisation', 'niveau', '62', 'argent', 'agriculture', 'industrie', 'totalenergie', 'cor', '64', 'actif', 'patrimoine', 'milliard', 'taux', 'cadre', 'travail', 'budget', 'direction', 'caisse', 'million', 'capital', 'salaire', 'pension', 'paiement', 'cotiser', 'facture', 'point', 'bourse', 'consommation', 'travailler', 'immobilier']\n",
    "kws_fair_extended = ['gagnant', 'perdant', 'justice', 'morale', 'inégalité', 'riche', 'pauvre', 'pénibilité', 'carrière long', 'solidarité', 'décent', 'bonne santé', 'régime spécial', 'régime complémentaire', 'Agirc-Arrco', 'redistributif', 'femme', 'cotisation', 'équité', 'égalité', 'discrimination', 'privilège', 'richesse', 'classe sociale', 'précarité', '65', 'prime', 'fonctionnaire', 'public', 'mari', 'police', 'sncf', 'mixte', 'cheminots', 'niveau', 'edf', 'agriculture', 'industrie', 'ratp', 'fraude', 'jeunesse', 'totalenergie', 'usine', '64', 'marier', 'milliardaire', 'minimal', 'jeune', 'patrimoine', 'index', 'policier', 'taux', 'couple', 'cadre', 'plafond', 'homme', 'ouvrier', 'partage', 'social', 'capital', 'fortune', 'enseignant', 'smic', 'senior', 'agriculteur']\n",
    "kws_risktime_extended = ['assurance', 'génération futur', 'génération', 'enfant', 'protection', 'mutualisation', 'myopie', 'déséquilibre', 'compréhension', 'ajustement automatique', 'ajustement', 'décote', 'bonus', 'malus', 'long terme', 'court terme', 'durabilité', 'vulnérabilité', 'incertitude', 'volatilité', 'futur', 'sécurité', 'aléas', 'imprévisibilité', 'imprévisible', 'scénario', 'jeune', 'climat', 'dette', '65', 'crise', 'cdc', 'vie', 'climatique', 'carbone', 'jeunesse', 'cor', 'soin', 'plan', 'index', 'temps', 'population', 'enfance', 'trimestre', 'naissance', 'ehpad']\n",
    "kws_proc_extended = ['processus législatif', 'processus', 'vote', 'consultation', 'dialogue social', 'partenaire social', 'grève', 'manifestation', 'communication', 'confiance', 'polarisation', 'démocratie', '49', 'COR', 'période transition', 'motion censure', 'transparence', 'négociation', 'consensus', 'référendum', 'compromis', 'mobilisation', 'syndicat', 'liot', 'peuple', 'amendement', 'anti', 'populaire', 'crise', 'report', 'cgt', 'commission', 'police', 'intersyndicale', 'gauche', 'article', 'loi', 'leader', 'pause', 'examen', 'martinez', 'manifestant', 'motion', 'texte', 'dussopt', 'macron', 'projet', 'bloquer', 'borne', 'initiative', 'proposition', 'gouvernement', 'nupes', 'grenade', 'policier', 'censure', 'droit', 'rn', 'syndical', 'berger', 'obstruction', 'presse', 'mouvement', 'medef', 'crs', 'constitutionnel', 'lr', 'droite', 'force', 'ordre', 'bardella', 'accord', 'cnr', 'groupe', 'communiste', 'etat', 'riester', 'garde', 'rip', 'politique', 'dysfonctionnement', 'voter', 'journaliste', 'cfdt', 'revendication', 'retrait', 'juge', 'tribunal', 'maire', 'vert', 'violence']\n",
    "\n",
    "# Compute weighted similarity scores for each document based on keyword sets for different dimensions\n",
    "economy_sims_weighted = keyword_sim_idfweight(kws_econ_extended, document_centroids_in, voc_idf)\n",
    "fairness_sims_weighted = keyword_sim_idfweight(kws_fair_extended, document_centroids_in, voc_idf)\n",
    "risktime_sims_weighted = keyword_sim_idfweight(kws_risktime_extended, document_centroids_in, voc_idf)\n",
    "process_sims_weighted = keyword_sim_idfweight(kws_proc_extended, document_centroids_in, voc_idf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2156951c-9f6c-4af5-be2a-ded9246df40f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame to store the weighted similarity scores and associated document information\n",
    "sim_df = pd.DataFrame()\n",
    "\n",
    "sim_df['economy_sim_weighted'] = economy_sims_weighted\n",
    "sim_df['fairness_sim_weighted'] = fairness_sims_weighted\n",
    "sim_df['risktime_sim_weighted'] = risktime_sims_weighted\n",
    "sim_df['process_sim_weighted'] = process_sims_weighted\n",
    "\n",
    "sim_df['article'] = documents\n",
    "sim_df['article_processed'] = documents_processed\n",
    "sim_df['Date'] = df_merge['Date'].values\n",
    "sim_df['Source'] = df_merge['Source'].values\n",
    "\n",
    "sim_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f01e5e-296b-4051-863b-f175aa3a993f",
   "metadata": {},
   "source": [
    "# Keyword count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7af70162-39e9-4643-b526-b51247c4ee92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count keywords\n",
    "def count_keywords(article, keywords):\n",
    "  keyword_set = set(keywords)\n",
    "  found_words = set()\n",
    "  count = 0\n",
    "\n",
    "  for word in article.split():\n",
    "    if word in keyword_set:\n",
    "      if word not in found_words:\n",
    "        count += 1\n",
    "        found_words.add(word)\n",
    "\n",
    "  return count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036a5262-8002-47a7-bcde-96312aeef3ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count keyword occurrences\n",
    "sim_df['econ_count'] = 0\n",
    "sim_df['fair_count'] = 0\n",
    "sim_df['risktime_count'] = 0\n",
    "sim_df['proc_count'] = 0\n",
    "\n",
    "for index, row in sim_df.iterrows():\n",
    "    sim_df.loc[index,'econ_count'] = count_keywords(row['article_processed'], kws_econ_extended)\n",
    "    sim_df.loc[index,'fair_count'] = count_keywords(row['article_processed'], kws_fair_extended)\n",
    "    sim_df.loc[index,'risktime_count'] = count_keywords(row['article_processed'], kws_risktime_extended)\n",
    "    sim_df.loc[index,'proc_count'] = count_keywords(row['article_processed'], kws_proc_extended)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30ab339f-e615-4727-afd5-4288555b0cc4",
   "metadata": {},
   "source": [
    "# Cutting the corpora"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8932e4be-2935-4248-9ea8-77fad9b03f2a",
   "metadata": {},
   "source": [
    "## Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28cd30c9-eefe-442f-b19b-2b7b28afe49b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate the DataFrame into media and parliament subsets based on the 'Source' column\n",
    "sim_df_media = sim_df[sim_df['Source'] == 'Media']\n",
    "sim_df_parl = sim_df[sim_df['Source'] == 'Parliament']\n",
    "print(f'The shape of `sim_df_media` is {sim_df_media.shape} and the shape of `sim_df_parl is {sim_df_parl.shape}`')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54586bc0-f78b-49d7-8abf-37de2656b943",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_media[sim_df_media['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_media[sim_df_media['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_media[sim_df_media['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_media[sim_df_media['proc_count'] >= filter_threshold]\n",
    "sim_df\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12,8))\n",
    "\n",
    "axs[0,0].hist(sim_df_economy['economy_sim_weighted'], bins=np.arange(min(sim_df_economy['economy_sim_weighted']), max(sim_df_economy['economy_sim_weighted']) + 0.001, 0.001), color='dodgerblue')\n",
    "axs[0,1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.arange(min(sim_df_fairness['fairness_sim_weighted']), max(sim_df_fairness['fairness_sim_weighted']) + 0.001, 0.001), color='orange')  \n",
    "axs[1,0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.arange(min(sim_df_risktime['risktime_sim_weighted']), max(sim_df_risktime['risktime_sim_weighted']) + 0.001, 0.001), color='seagreen')\n",
    "axs[1,1].hist(sim_df_process['process_sim_weighted'], bins=np.arange(min(sim_df_process['process_sim_weighted']), max(sim_df_process['process_sim_weighted']) + 0.001, 0.001), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy (N = {len(sim_df_economy)})')\n",
    "axs[0, 1].set_title(f'Fairness (N = {len(sim_df_fairness)})')\n",
    "axs[1, 0].set_title(f'Risk & Time (N = {len(sim_df_risktime)})')\n",
    "axs[1, 1].set_title(f'Process (N = {len(sim_df_process)})')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0,0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0,1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1,0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1,1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "fig.suptitle(f\"Articles' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21e29307-ff0e-45eb-821f-0759d045483b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_media[sim_df_media['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_media[sim_df_media['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_media[sim_df_media['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_media[sim_df_media['proc_count'] >= filter_threshold]\n",
    "sim_df\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12,8))\n",
    "\n",
    "axs[0,0].hist(sim_df_economy['economy_sim_weighted'], bins=np.arange(min(sim_df_economy['economy_sim_weighted']), max(sim_df_economy['economy_sim_weighted']) + 0.001, 0.001), color='dodgerblue')\n",
    "axs[0,1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.arange(min(sim_df_fairness['fairness_sim_weighted']), max(sim_df_fairness['fairness_sim_weighted']) + 0.001, 0.001), color='orange')  \n",
    "axs[1,0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.arange(min(sim_df_risktime['risktime_sim_weighted']), max(sim_df_risktime['risktime_sim_weighted']) + 0.001, 0.001), color='seagreen')\n",
    "axs[1,1].hist(sim_df_process['process_sim_weighted'], bins=np.arange(min(sim_df_process['process_sim_weighted']), max(sim_df_process['process_sim_weighted']) + 0.001, 0.001), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy')\n",
    "axs[0, 1].set_title(f'Fairness')\n",
    "axs[1, 0].set_title(f'Risk & Time')\n",
    "axs[1, 1].set_title(f'Process')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0,0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0,1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1,0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1,1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "#fig.suptitle(f\"Articles' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0568bc0-9df1-4b0e-9b86-b569af2e3e17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_parl[sim_df_parl['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_parl[sim_df_parl['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_parl[sim_df_parl['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_parl[sim_df_parl['proc_count'] >= filter_threshold]\n",
    "sim_df\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12,8))\n",
    "\n",
    "axs[0,0].hist(sim_df_economy['economy_sim_weighted'], bins=np.arange(min(sim_df_economy['economy_sim_weighted']), max(sim_df_economy['economy_sim_weighted']) + 0.001, 0.001), color='dodgerblue')\n",
    "axs[0,1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.arange(min(sim_df_fairness['fairness_sim_weighted']), max(sim_df_fairness['fairness_sim_weighted']) + 0.001, 0.001), color='orange')  \n",
    "axs[1,0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.arange(min(sim_df_risktime['risktime_sim_weighted']), max(sim_df_risktime['risktime_sim_weighted']) + 0.001, 0.001), color='seagreen')\n",
    "axs[1,1].hist(sim_df_process['process_sim_weighted'], bins=np.arange(min(sim_df_process['process_sim_weighted']), max(sim_df_process['process_sim_weighted']) + 0.001, 0.001), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy (N = {len(sim_df_economy)})')\n",
    "axs[0, 1].set_title(f'Fairness (N = {len(sim_df_fairness)})')\n",
    "axs[1, 0].set_title(f'Risk & Time (N = {len(sim_df_risktime)})')\n",
    "axs[1, 1].set_title(f'Process (N = {len(sim_df_process)})')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0,0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0,1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1,0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1,1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "fig.suptitle(f\"Speeches' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e2d997-a5e5-4022-a593-4724d504bcef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_parl[sim_df_parl['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_parl[sim_df_parl['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_parl[sim_df_parl['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_parl[sim_df_parl['proc_count'] >= filter_threshold]\n",
    "sim_df\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12,8))\n",
    "\n",
    "axs[0,0].hist(sim_df_economy['economy_sim_weighted'], bins=np.arange(min(sim_df_economy['economy_sim_weighted']), max(sim_df_economy['economy_sim_weighted']) + 0.001, 0.001), color='dodgerblue')\n",
    "axs[0,1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.arange(min(sim_df_fairness['fairness_sim_weighted']), max(sim_df_fairness['fairness_sim_weighted']) + 0.001, 0.001), color='orange')  \n",
    "axs[1,0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.arange(min(sim_df_risktime['risktime_sim_weighted']), max(sim_df_risktime['risktime_sim_weighted']) + 0.001, 0.001), color='seagreen')\n",
    "axs[1,1].hist(sim_df_process['process_sim_weighted'], bins=np.arange(min(sim_df_process['process_sim_weighted']), max(sim_df_process['process_sim_weighted']) + 0.001, 0.001), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy')\n",
    "axs[0, 1].set_title(f'Fairness')\n",
    "axs[1, 0].set_title(f'Risk & Time')\n",
    "axs[1, 1].set_title(f'Process')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0,0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0,1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0,1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1,0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1,1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1,1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "#fig.suptitle(f\"Speeches' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b2e11ad-b051-417f-be77-f4e23bcbb08e",
   "metadata": {},
   "source": [
    "## Min-max normalisation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa974d1-3504-4108-b127-df76347137f8",
   "metadata": {},
   "source": [
    "### Media"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0646583-9eb6-4a46-b19b-3d85e7e812a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the distributions\n",
    "sim_df_media_economy_norm = sim_df_media.copy()\n",
    "sim_df_media_fairness_norm = sim_df_media.copy()\n",
    "sim_df_media_risktime_norm = sim_df_media.copy()\n",
    "sim_df_media_process_norm = sim_df_media.copy()\n",
    "\n",
    "sim_df_media_economy_norm['economy_sim_weighted'] = (sim_df_media_economy_norm['economy_sim_weighted'] - sim_df_media_economy_norm['economy_sim_weighted'].min()) / (sim_df_media_economy_norm['economy_sim_weighted'].max() - sim_df_media_economy_norm['economy_sim_weighted'].min())\n",
    "sim_df_media_fairness_norm['fairness_sim_weighted'] = (sim_df_media_fairness_norm['fairness_sim_weighted'] - sim_df_media_fairness_norm['fairness_sim_weighted'].min()) / (sim_df_media_fairness_norm['fairness_sim_weighted'].max() - sim_df_media_fairness_norm['fairness_sim_weighted'].min())\n",
    "sim_df_media_risktime_norm['risktime_sim_weighted'] = (sim_df_media_risktime_norm['risktime_sim_weighted'] - sim_df_media_risktime_norm['risktime_sim_weighted'].min()) / (sim_df_media_risktime_norm['risktime_sim_weighted'].max() - sim_df_media_risktime_norm['risktime_sim_weighted'].min())\n",
    "sim_df_media_process_norm['process_sim_weighted'] = (sim_df_media_process_norm['process_sim_weighted'] - sim_df_media_process_norm['process_sim_weighted'].min()) / (sim_df_media_process_norm['process_sim_weighted'].max() - sim_df_media_process_norm['process_sim_weighted'].min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276b39c0-4154-43d3-9d1c-0ff1c8f51dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_media_economy_norm[sim_df_media_economy_norm['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_media_fairness_norm[sim_df_media_fairness_norm['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_media_risktime_norm[sim_df_media_risktime_norm['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_media_process_norm[sim_df_media_process_norm['proc_count'] >= filter_threshold]\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12, 8))\n",
    "\n",
    "axs[0, 0].hist(sim_df_economy['economy_sim_weighted'], bins=np.linspace(0, 1, 101), color='dodgerblue')\n",
    "axs[0, 1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.linspace(0, 1, 101), color='orange')\n",
    "axs[1, 0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.linspace(0, 1, 101), color='seagreen')\n",
    "axs[1, 1].hist(sim_df_process['process_sim_weighted'], bins=np.linspace(0, 1, 101), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy (N = {len(sim_df_economy)})')\n",
    "axs[0, 1].set_title(f'Fairness (N = {len(sim_df_fairness)})')\n",
    "axs[1, 0].set_title(f'Risk & Time (N = {len(sim_df_risktime)})')\n",
    "axs[1, 1].set_title(f'Process (N = {len(sim_df_process)})')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0, 0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0, 0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0, 1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0, 1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1, 0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1, 0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1, 1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1, 1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "fig.suptitle(f\"Articles' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06c1718-4c7e-4e1e-9b73-2c0cb97baeca",
   "metadata": {},
   "source": [
    "### Parliament"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d0113e-57f1-4140-9365-76fce42bb1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the distributions\n",
    "sim_df_parl_economy_norm = sim_df_parl.copy()\n",
    "sim_df_parl_fairness_norm = sim_df_parl.copy()\n",
    "sim_df_parl_risktime_norm = sim_df_parl.copy()\n",
    "sim_df_parl_process_norm = sim_df_parl.copy()\n",
    "\n",
    "sim_df_parl_economy_norm['economy_sim_weighted'] = (sim_df_parl_economy_norm['economy_sim_weighted'] - sim_df_parl_economy_norm['economy_sim_weighted'].min()) / (sim_df_parl_economy_norm['economy_sim_weighted'].max() - sim_df_parl_economy_norm['economy_sim_weighted'].min())\n",
    "sim_df_parl_fairness_norm['fairness_sim_weighted'] = (sim_df_parl_fairness_norm['fairness_sim_weighted'] - sim_df_parl_fairness_norm['fairness_sim_weighted'].min()) / (sim_df_parl_fairness_norm['fairness_sim_weighted'].max() - sim_df_parl_fairness_norm['fairness_sim_weighted'].min())\n",
    "sim_df_parl_risktime_norm['risktime_sim_weighted'] = (sim_df_parl_risktime_norm['risktime_sim_weighted'] - sim_df_parl_risktime_norm['risktime_sim_weighted'].min()) / (sim_df_parl_risktime_norm['risktime_sim_weighted'].max() - sim_df_parl_risktime_norm['risktime_sim_weighted'].min())\n",
    "sim_df_parl_process_norm['process_sim_weighted'] = (sim_df_parl_process_norm['process_sim_weighted'] - sim_df_parl_process_norm['process_sim_weighted'].min()) / (sim_df_parl_process_norm['process_sim_weighted'].max() - sim_df_parl_process_norm['process_sim_weighted'].min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5058b3c9-4db2-4722-b89a-2879e13fed38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on the number of keyword matches\n",
    "filter_threshold = 0\n",
    "\n",
    "# Sort for economy subplot\n",
    "sim_df_economy = sim_df_parl_economy_norm[sim_df_parl_economy_norm['econ_count'] >= filter_threshold]\n",
    "sim_df_fairness = sim_df_parl_fairness_norm[sim_df_parl_fairness_norm['fair_count'] >= filter_threshold]\n",
    "sim_df_risktime = sim_df_parl_risktime_norm[sim_df_parl_risktime_norm['risktime_count'] >= filter_threshold]\n",
    "sim_df_process = sim_df_parl_process_norm[sim_df_parl_process_norm['proc_count'] >= filter_threshold]\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(12, 8))\n",
    "\n",
    "axs[0, 0].hist(sim_df_economy['economy_sim_weighted'], bins=np.linspace(0, 1, 101), color='dodgerblue')\n",
    "axs[0, 1].hist(sim_df_fairness['fairness_sim_weighted'], bins=np.linspace(0, 1, 101), color='orange')\n",
    "axs[1, 0].hist(sim_df_risktime['risktime_sim_weighted'], bins=np.linspace(0, 1, 101), color='seagreen')\n",
    "axs[1, 1].hist(sim_df_process['process_sim_weighted'], bins=np.linspace(0, 1, 101), color='coral')\n",
    "\n",
    "# Add titles to each subplot\n",
    "axs[0, 0].set_title(f'Economy (N = {len(sim_df_economy)})')\n",
    "axs[0, 1].set_title(f'Fairness (N = {len(sim_df_fairness)})')\n",
    "axs[1, 0].set_title(f'Risk & Time (N = {len(sim_df_risktime)})')\n",
    "axs[1, 1].set_title(f'Process (N = {len(sim_df_process)})')\n",
    "\n",
    "# Add info about stats\n",
    "economy_mean = np.mean(sim_df_economy['economy_sim_weighted'])\n",
    "economy_median = np.median(sim_df_economy['economy_sim_weighted'])\n",
    "axs[0, 0].text(0.95, 0.95, f'Mean: {economy_mean:.3f}\\nMedian: {economy_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0, 0].transAxes)\n",
    "\n",
    "fairness_mean = np.mean(sim_df_fairness['fairness_sim_weighted'])\n",
    "fairness_median = np.median(sim_df_fairness['fairness_sim_weighted'])\n",
    "axs[0, 1].text(0.95, 0.95, f'Mean: {fairness_mean:.3f}\\nMedian: {fairness_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[0, 1].transAxes)\n",
    "\n",
    "risktime_mean = np.mean(sim_df_risktime['risktime_sim_weighted'])\n",
    "risktime_median = np.median(sim_df_risktime['risktime_sim_weighted'])\n",
    "axs[1, 0].text(0.95, 0.95, f'Mean: {risktime_mean:.3f}\\nMedian: {risktime_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1, 0].transAxes)\n",
    "\n",
    "process_mean = np.mean(sim_df_process['process_sim_weighted'])\n",
    "process_median = np.median(sim_df_process['process_sim_weighted'])\n",
    "axs[1, 1].text(0.95, 0.95, f'Mean: {process_mean:.3f}\\nMedian: {process_median:.3f}', horizontalalignment='right', verticalalignment='top', transform=axs[1, 1].transAxes)\n",
    "\n",
    "# Add a general title to the figure\n",
    "fig.suptitle(f\"Speeches' similarity filtering articles with ≥{filter_threshold} different keywords\", size='x-large')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecea280b-1aa2-4638-8227-3eccc3730470",
   "metadata": {},
   "source": [
    "## Slicing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d2bd90-dc51-4348-88f9-cbde57e6963d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filtering on common similarity threshold\n",
    "shared_sim_threshold = 0.9\n",
    "\n",
    "# Sort for economy subplot\n",
    "top_economy_parl = sim_df_parl_economy_norm[sim_df_parl_economy_norm['economy_sim_weighted'] >= shared_sim_threshold]\n",
    "top_fairness_parl = sim_df_parl_fairness_norm[sim_df_parl_fairness_norm['fairness_sim_weighted'] >= shared_sim_threshold]\n",
    "top_risktime_parl = sim_df_parl_risktime_norm[sim_df_parl_risktime_norm['risktime_sim_weighted'] >= shared_sim_threshold]\n",
    "top_process_parl = sim_df_parl_process_norm[sim_df_parl_process_norm['process_sim_weighted'] >= shared_sim_threshold]\n",
    "\n",
    "# Prepare the four corpora\n",
    "corpus_economy_parl = top_economy_parl['article_processed']\n",
    "corpus_fairness_parl = top_fairness_parl['article_processed']\n",
    "corpus_risktime_parl = top_risktime_parl['article_processed']\n",
    "corpus_process_parl = top_process_parl['article_processed']\n",
    "\n",
    "print('The number of PARLIAMENTARY SPEECHES for each corpus is:' +\n",
    "      '\\n - Economic dimension: ' + str(len(corpus_economy_parl)) +\n",
    "      '\\n - Fairness dimension: ' + str(len(corpus_fairness_parl)) +\n",
    "      '\\n - Risk & Time dimension: ' + str(len(corpus_risktime_parl)) +\n",
    "      '\\n - Process dimension: ' + str(len(corpus_process_parl)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c3084c6-2780-441d-8f2f-cd340203f549",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort for economy subplot\n",
    "top_economy_media = sim_df_media_economy_norm[sim_df_media_economy_norm['economy_sim_weighted'] >= shared_sim_threshold]\n",
    "top_fairness_media = sim_df_media_fairness_norm[sim_df_media_fairness_norm['fairness_sim_weighted'] >= shared_sim_threshold]\n",
    "top_risktime_media = sim_df_media_risktime_norm[sim_df_media_risktime_norm['risktime_sim_weighted'] >= shared_sim_threshold]\n",
    "top_process_media = sim_df_media_process_norm[sim_df_media_process_norm['process_sim_weighted'] >= shared_sim_threshold]\n",
    "\n",
    "# Prepare the four corpora\n",
    "corpus_economy_media = top_economy_media['article_processed']\n",
    "corpus_fairness_media = top_fairness_media['article_processed']\n",
    "corpus_risktime_media = top_risktime_media['article_processed']\n",
    "corpus_process_media = top_process_media['article_processed']\n",
    "\n",
    "print('The number of MEDIA ARTICLES for each corpus is:' +\n",
    "      '\\n - Economic dimension: ' + str(len(corpus_economy_media)) +\n",
    "      '\\n - Fairness dimension: ' + str(len(corpus_fairness_media)) +\n",
    "      '\\n - Risk & Time dimension: ' + str(len(corpus_risktime_media)) +\n",
    "      '\\n - Process dimension: ' + str(len(corpus_process_media)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6033fda-8947-4d63-8f70-bf2f49074b4b",
   "metadata": {},
   "source": [
    "# Time series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75e6347a-30dd-4214-8e60-45b0b01fe225",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resampling to transform data into time series\n",
    "econ_media_weekly = top_economy_media.resample('W', on='Date').count()\n",
    "fair_media_weekly = top_fairness_media.resample('W', on='Date').count()\n",
    "risk_media_weekly = top_risktime_media.resample('W', on='Date').count()\n",
    "proc_media_weekly = top_process_media.resample('W', on='Date').count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13fe0ddd-fabb-4558-b7fc-d648871032b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new figure with subplots\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "# Plot dimensions' data\n",
    "ax.plot(econ_media_weekly.index, econ_media_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Economy (N = {len(top_economy_media)})', color='dodgerblue')\n",
    "ax.plot(fair_media_weekly.index, fair_media_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Fairness (N = {len(top_fairness_media)})', color='orange')\n",
    "ax.plot(risk_media_weekly.index, risk_media_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Risk & Time (N = {len(top_risktime_media)})', color='seagreen')\n",
    "ax.plot(proc_media_weekly.index, proc_media_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Process (N = {len(top_process_media)})', color='coral')\n",
    "\n",
    "# Set axis labels and title\n",
    "ax.set_xlabel('Date')\n",
    "ax.set_ylabel(f'Weekly N of articles')\n",
    "ax.set_title(f'Dimensions media coverage (documents with similarity > {shared_sim_threshold})')\n",
    "\n",
    "# Show legend\n",
    "ax.legend()\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983760b8-cbc0-4c3b-9c5a-7d0d4a4db03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resampling to transform data into time series\n",
    "econ_parl_weekly = top_economy_parl.resample('W', on='Date').count()\n",
    "fair_parl_weekly = top_fairness_parl.resample('W', on='Date').count()\n",
    "risk_parl_weekly = top_risktime_parl.resample('W', on='Date').count()\n",
    "proc_parl_weekly = top_process_parl.resample('W', on='Date').count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21e0203-783e-4567-a468-d1c6aacd7743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new figure with subplots\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "# Plot dimensions' data\n",
    "ax.plot(econ_parl_weekly.index, econ_parl_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Economy (N = {len(top_economy_parl)})', color='dodgerblue')\n",
    "ax.plot(fair_parl_weekly.index, fair_parl_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Fairness (N = {len(top_fairness_parl)})', color='orange')\n",
    "ax.plot(risk_parl_weekly.index, risk_parl_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Risk & Time (N = {len(top_risktime_parl)})', color='seagreen')\n",
    "ax.plot(proc_parl_weekly.index, proc_parl_weekly['Source'].rolling(3, min_periods=0).mean(), label=f'Process (N = {len(top_process_parl)})', color='coral')\n",
    "\n",
    "# Set axis labels and title\n",
    "ax.set_xlabel('Date')\n",
    "ax.set_ylabel(f'Weekly N of articles')\n",
    "ax.set_title(f'Dimensions media coverage (documents with similarity > {shared_sim_threshold})')\n",
    "\n",
    "# Show legend\n",
    "ax.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f381c751-f544-480d-ae8b-8775b28b6e36",
   "metadata": {},
   "source": [
    "## Normalisation by number of articles published each unit of time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b5fc05-1734-42ef-b15f-4acae9cf0f89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample to obtain time series\n",
    "econ_media_source = top_economy_media.resample('W', on='Date').count()['Source']\n",
    "fair_media_source = top_fairness_media.resample('W', on='Date').count()['Source']\n",
    "risk_media_source = top_risktime_media.resample('W', on='Date').count()['Source']\n",
    "proc_media_source = top_process_media.resample('W', on='Date').count()['Source']\n",
    "\n",
    "documents_media_daily = sim_df_media.resample('W', on='Date').count()['Source']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96eed27c-4035-4d21-b579-a69fd55abf5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the extracted columns along columns axis\n",
    "weekly_data_media = pd.concat([documents_media_daily, econ_media_source, fair_media_source, risk_media_source, proc_media_source], axis=1)\n",
    "weekly_data_media.columns = ['N_Documents', 'Economy', 'Fairness', 'RiskTime', 'Process']\n",
    "weekly_data_media.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39eb9f2a-1de5-413e-a5d7-60826df3ecbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each dimension's time series by the total N of docs published on that unit of time\n",
    "weekly_data_media['Economy_norm'] = weekly_data_media['Economy'] / weekly_data_media['N_Documents']\n",
    "weekly_data_media['Fairness_norm'] = weekly_data_media['Fairness'] / weekly_data_media['N_Documents']\n",
    "weekly_data_media['RiskTime_norm'] = weekly_data_media['RiskTime'] / weekly_data_media['N_Documents']\n",
    "weekly_data_media['Process_norm'] = weekly_data_media['Process'] / weekly_data_media['N_Documents']\n",
    "\n",
    "weekly_data_media.fillna(0, inplace=True)\n",
    "weekly_data_media.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13eec81f-0a21-4364-9c80-bf050f735636",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut the last week because we don't have reliable data\n",
    "weekly_data_media_cut = weekly_data_media.iloc[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b885935-062f-41b8-9f50-168e26c6a784",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define important dates \n",
    "important_dates = {\n",
    "    datetime.datetime(2022, 12, 31): 'Macron\\nannounces\\nthe reform',\n",
    "    datetime.datetime(2023, 1, 30): 'Law is\\nintroduced\\nin AN',\n",
    "    datetime.datetime(2023, 3, 16): 'PM invokes\\nArt. 49.3',\n",
    "    datetime.datetime(2023, 4, 14): 'Consitutional\\nCouncil ratifies\\nthe law',\n",
    "}\n",
    "\n",
    "# Plot the normalized time series\n",
    "plt.figure(figsize=(12, 4))\n",
    "ax = plt.gca()  # Get the current Axes instance\n",
    "\n",
    "plt.plot(weekly_data_media_cut.index, weekly_data_media_cut['Economy_norm'].rolling(2).mean(), label=f\"Economy (N = {int(sum(weekly_data_media_cut['Economy']))})\", color='dodgerblue')\n",
    "plt.plot(weekly_data_media_cut.index, weekly_data_media_cut['Fairness_norm'].rolling(2).mean(), label=f\"Fairness (N = {int(sum(weekly_data_media_cut['Fairness']))})\", color='orange')\n",
    "plt.plot(weekly_data_media_cut.index, weekly_data_media_cut['RiskTime_norm'].rolling(2).mean(), label=f\"Risk & Time (N = {int(sum(weekly_data_media_cut['RiskTime']))})\", color='seagreen')\n",
    "plt.plot(weekly_data_media_cut.index, weekly_data_media_cut['Process_norm'].rolling(2).mean(), label=f\"Process (N = {int(sum(weekly_data_media_cut['Process']))})\", color='coral') \n",
    "\n",
    "plt.xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "plt.ylabel('Normalized count')\n",
    "#plt.ylim(0, 0.19)\n",
    "#plt.title('Weekly number of media articles for each dimension\\n(normalized over total articles published each week)')\n",
    "#ax.legend(loc='upper right')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=0.1)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "#plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90cfa264-5db5-4945-ada2-fe298faa8ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample to obtain time series\n",
    "econ_parl_source = top_economy_parl.resample('W', on='Date').count()['Source']\n",
    "fair_parl_source = top_fairness_parl.resample('W', on='Date').count()['Source']\n",
    "risk_parl_source = top_risktime_parl.resample('W', on='Date').count()['Source']\n",
    "proc_parl_source = top_process_parl.resample('W', on='Date').count()['Source']\n",
    "\n",
    "documents_parl_daily = sim_df_parl.resample('W', on='Date').count()['Source']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3830897-e6e7-4172-a2e2-f214f8263d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the extracted columns along columns axis\n",
    "weekly_data_parl = pd.concat([documents_parl_daily, econ_parl_source, fair_parl_source, risk_parl_source, proc_parl_source], axis=1)\n",
    "weekly_data_parl.columns = ['N_Documents', 'Economy', 'Fairness', 'RiskTime', 'Process']\n",
    "weekly_data_parl.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a484a60e-a364-4d87-8e09-18d6f3ab1f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each dimension's time series by the total N of docs published on that unit of time\n",
    "weekly_data_parl['Economy_norm'] = weekly_data_parl['Economy'] / weekly_data_parl['N_Documents']\n",
    "weekly_data_parl['Fairness_norm'] = weekly_data_parl['Fairness'] / weekly_data_parl['N_Documents']\n",
    "weekly_data_parl['RiskTime_norm'] = weekly_data_parl['RiskTime'] / weekly_data_parl['N_Documents']\n",
    "weekly_data_parl['Process_norm'] = weekly_data_parl['Process'] / weekly_data_parl['N_Documents']\n",
    "\n",
    "weekly_data_parl.fillna(0, inplace=True)\n",
    "weekly_data_parl.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89b6542d-94cf-4dd5-8cb9-b26acff03311",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the normalized time series\n",
    "plt.figure(figsize=(12, 4))\n",
    "ax = plt.gca()  # Get the current Axes instance\n",
    "\n",
    "plt.plot(weekly_data_parl.index, weekly_data_parl['Economy_norm'].rolling(2).mean(), label=f\"Economy (N = {int(sum(weekly_data_parl['Economy']))})\", color='dodgerblue')\n",
    "plt.plot(weekly_data_parl.index, weekly_data_parl['Fairness_norm'].rolling(2).mean(), label=f\"Fairness (N = {int(sum(weekly_data_parl['Fairness']))})\", color='orange')\n",
    "plt.plot(weekly_data_parl.index, weekly_data_parl['RiskTime_norm'].rolling(2).mean(), label=f\"Risk & Time (N = {int(sum(weekly_data_parl['RiskTime']))})\", color='seagreen')\n",
    "plt.plot(weekly_data_parl.index, weekly_data_parl['Process_norm'].rolling(2).mean(), label=f\"Process (N = {int(sum(weekly_data_parl['Process']))})\", color='coral')\n",
    "\n",
    "plt.xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "plt.ylabel('Normalized count')\n",
    "#plt.ylim(0, 0.5)\n",
    "#plt.title('Weekly number of media articles for each dimension\\n(normalized over total articles published each week)')\n",
    "#ax.legend(loc='upper right')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, color='gray', linestyle=':', alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad470ef1-8091-492c-8c3e-9056e33a9d97",
   "metadata": {},
   "source": [
    "# Topic modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42d41d30-7a1c-44d2-8b91-a84f1bacbe71",
   "metadata": {},
   "source": [
    "## Media"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8f3a76f-aa6c-4ade-8f78-128659b70d33",
   "metadata": {},
   "source": [
    "### Economy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95b7acd4-5722-4c26-9067-35bea7295b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "n_components = 10 \n",
    "n_top_words = 15 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aff09d9-dc67-47d1-9082-cd4d4d4f5fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the TF-IDF matrix\n",
    "tfidf_vectorizer = TfidfVectorizer(\n",
    "    max_df=0.95, min_df=5,  \n",
    "    max_features=n_features, \n",
    "    ngram_range=(1, 1))\n",
    "\n",
    "tfidf_economy = tfidf_vectorizer.fit_transform(corpus_economy_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0ffe8a-1e7c-49c4-8385-dbbd900fbe37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform NMF\n",
    "seed = 42\n",
    "nmf_economy = NMF(n_components=n_components,\n",
    "                  random_state=seed, \n",
    "                  alpha_W=0, # Constant that multiplies the regularization terms of W. Set it to zero (default) to have no regularization on W\n",
    "                  l1_ratio=0, # The regularization mixing parameter 0-1. l1_ratio = 0, penalty is an elementwise L2 penalty (aka Frobenius Norm). For l1_ratio = 1 it is an elementwise L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.\n",
    "                  max_iter=1000,\n",
    "                  init='random'\n",
    "                  ).fit(tfidf_economy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be107e6b-1c4a-4851-bf9a-20f02c10d1fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to plot the top words for each topic from a fitted topic model\n",
    "def plot_top_words(model, feature_names, n_top_words, title, name, n_topics, n_seed):\n",
    "    fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)\n",
    "    axes = axes.flatten()\n",
    "    for topic_idx, topic in enumerate(model.components_):\n",
    "        top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]\n",
    "        top_features = [feature_names[i] for i in top_features_ind]\n",
    "        weights = topic[top_features_ind]\n",
    "\n",
    "        ax = axes[topic_idx]\n",
    "        ax.barh(top_features, weights, height=0.9)\n",
    "        ax.set_title(f\"Topic {topic_idx +1}\", fontdict={\"fontsize\": 30})\n",
    "        ax.invert_yaxis()\n",
    "        ax.tick_params(axis=\"both\", which=\"major\", labelsize=20)\n",
    "        for i in \"top right left\".split():\n",
    "            ax.spines[i].set_visible(False)\n",
    "        fig.suptitle(title, fontsize=50)\n",
    "\n",
    "    plt.subplots_adjust(top=0.87, bottom=0.05, wspace=0.7, hspace=0.2)\n",
    "    #plt.savefig('./figures/'+str(name)+'_topics'+str(n_topics)+'_seed'+str(n_seed)+'.png')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10aff18f-3823-49e6-b751-eb8e5c2867ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the TF-IDF vectorizer\n",
    "tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Plot the topic modelling\n",
    "plot_top_words(\n",
    "    nmf_economy, tfidf_feature_names, n_top_words, f\"Topics in articles on economic dimension (N = {len(corpus_economy_media)})\", 'scikit_selmedia_nofiltSW', n_components, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96eb3a72-e485-4acf-a26d-1ffa37cdd96c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the vectorizer\n",
    "feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Get the top 3 indices for each topic\n",
    "top_indices = nmf_economy.components_.argsort()[:,-5:]\n",
    "\n",
    "# Create empty list to store top words  \n",
    "top_words_economy = []\n",
    "\n",
    "# Loop through each topic\n",
    "for topic in top_indices:\n",
    "  \n",
    "  # Extract top 3 words\n",
    "  top_3 = [feature_names[i] for i in topic]\n",
    "  \n",
    "  # Add to list\n",
    "  top_words_economy.extend(top_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee82db3-4023-4308-92e1-7eafbd6480df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the document-topic matrix\n",
    "doc_topic_matrix = nmf_economy.transform(tfidf_economy)\n",
    "\n",
    "# Calculate topic frequencies across corpus\n",
    "topic_freqs = doc_topic_matrix.sum(axis=0) \n",
    "\n",
    "# Get top 3 words for each topic\n",
    "top_word_indices = nmf_economy.components_.argsort()[:, -3:]\n",
    "top_words = [[tfidf_feature_names[i] for i in topic[::-1]] for topic in top_word_indices]\n",
    "top_words = [' '.join(words) for words in top_words]\n",
    "\n",
    "# Combine topics and top words\n",
    "topics_with_words = list(zip(topic_freqs, top_words))\n",
    "\n",
    "# Sort by frequency\n",
    "sorted_topics = sorted(topics_with_words, key=lambda x: x[0], reverse=False)\n",
    "\n",
    "# Extract sorted freqs and topic names  \n",
    "sorted_freqs, sorted_top_words = zip(*sorted_topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f1c8619-5453-4656-b127-291bbea29133",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot sorted topics and freqs  \n",
    "fig, ax = plt.subplots(figsize=(3.5, 6)) \n",
    "\n",
    "# Plot horizontal bars\n",
    "ax.barh(range(n_components), sorted_freqs, color='dodgerblue')\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax.set_yticks(range(n_components)) \n",
    "ax.set_yticklabels(sorted_top_words)\n",
    "\n",
    "# Remove x-axis tick labels  \n",
    "ax.set_xticks([]) \n",
    "\n",
    "# Axis labels\n",
    "ax.set_xlabel(\"Frequency\")\n",
    "\n",
    "# Title and tight layout\n",
    "ax.set_title(f\"Economy\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02762f95-8b6f-4bd6-aa25-d94cd6eb5848",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Document-topic matrix, where columns are the top words\n",
    "w_df = pd.DataFrame(doc_topic_matrix)\n",
    "w_df.columns = top_words\n",
    "w_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c1e07c-880f-4044-ae42-df7531d318a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_df_date = w_df.copy()\n",
    "w_df_date['Date'] = top_economy_media['Date'].values\n",
    "\n",
    "# Convert 'Date' column to datetime and set it as the index\n",
    "w_df_date['Date'] = pd.to_datetime(w_df_date['Date'])\n",
    "w_df_date.set_index('Date', inplace=True)\n",
    "\n",
    "# Resample at weekly level\n",
    "w_df_date = w_df_date.resample('W').sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a36236b7-66cc-49f4-a72e-61c8e3b187ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot\n",
    "colors = sns.color_palette(\"tab20\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12,5))\n",
    "\n",
    "# Plot a stackplot\n",
    "ax.stackplot(w_df_date.index, w_df_date.T, baseline='sym', labels=w_df_date.columns, colors=colors)\n",
    "\n",
    "# Move the legend off of the chart\n",
    "ax.legend(loc=(0,-0.6), ncol=3)\n",
    "\n",
    "plt.title(f'Topic evolution over time in media articles about the economic dimension (N = {len(w_df)})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00780ebf-00e3-4b3a-aaf4-d137eb1f66c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each value of the time series by the total weight across all topics\n",
    "w_df_date_norm = w_df_date.copy()\n",
    "w_df_date_norm['weekly_weight_sum'] = w_df_date_norm.apply(np.sum, axis=1)\n",
    "w_df_date_norm.iloc[:, :-1] = w_df_date_norm.iloc[:, :-1].div(w_df_date_norm['weekly_weight_sum'], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035964ed-cbe0-4b83-95ed-565f9f899d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"dette milliard public\",\n",
    "    \"prix production agricole\",\n",
    "    \"milliardaire riche fortune\",\n",
    "    \"assurance fonds vie\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='dodgerblue', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Economy dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.25), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=0.2)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61dd01b9-dfd1-42b3-965f-97fb5c15f05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"cotisation revenu social\",\n",
    "    \"logement immobilier taxe\",\n",
    "    \"scpi immobilier rendement\"\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='dodgerblue', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Economy dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=0.2)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "858841ed-0284-4ac9-b7a3-ddce5a6b3086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"per plan versement\",\n",
    "    \"capitalisation pension fonctionnaire\",\n",
    "    \"banque financier dollar\"\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='dodgerblue', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Economy dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=0.2)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d20cd4f-0a1e-4d26-9798-8e6aba1f8155",
   "metadata": {},
   "source": [
    "### Fairness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba5f7754-a7f2-4c0d-99f8-1ed551cc6ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the TF-IDF matrix\n",
    "tfidf_vectorizer = TfidfVectorizer(\n",
    "    max_df=0.95, min_df=5,  \n",
    "    max_features=n_features, \n",
    "    ngram_range=(1, 1))\n",
    "\n",
    "tfidf_fairness = tfidf_vectorizer.fit_transform(corpus_fairness_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "565de621-b8db-4a5c-8877-79eb75b385b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform NMF\n",
    "seed = 42\n",
    "nmf_fairness = NMF(n_components=n_components,\n",
    "                  random_state=seed, # Pass an int for reproducible results across multiple function calls.\n",
    "                  alpha_W=0, # Constant that multiplies the regularization terms of W. Set it to zero (default) to have no regularization on W\n",
    "                  l1_ratio=0, # The regularization mixing parameter 0-1. l1_ratio = 0, penalty is an elementwise L2 penalty (aka Frobenius Norm). For l1_ratio = 1 it is an elementwise L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.\n",
    "                  max_iter=1000,\n",
    "                  init='random'\n",
    "                  ).fit(tfidf_fairness)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3369665-a6e7-4842-a584-12a5a50b6bf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the TF-IDF vectorizer\n",
    "tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "plot_top_words(\n",
    "    nmf_fairness, tfidf_feature_names, n_top_words, f\"Topics in articles on fairness dimension (N = {len(corpus_fairness_media)})\", 'scikit_selmedia_nofiltSW', n_components, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48ebff06-bbde-452a-9b32-9a526a2db835",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the vectorizer\n",
    "feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Get the top 3 indices for each topic\n",
    "top_indices = nmf_fairness.components_.argsort()[:,-5:]\n",
    "\n",
    "# Create empty list to store top words  \n",
    "top_words_fairness = []\n",
    "\n",
    "# Loop through each topic\n",
    "for topic in top_indices:\n",
    "  \n",
    "  # Extract top 3 words\n",
    "  top_3 = [feature_names[i] for i in topic]\n",
    "  \n",
    "  # Add to list\n",
    "  top_words_fairness.extend(top_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7982973-8359-48e2-b4e3-26f8d5000205",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the document-topic matrix\n",
    "doc_topic_matrix = nmf_fairness.transform(tfidf_fairness)\n",
    "\n",
    "# Calculate topic frequencies across corpus\n",
    "topic_freqs = doc_topic_matrix.sum(axis=0) \n",
    "\n",
    "# Get top 3 words for each topic\n",
    "top_word_indices = nmf_fairness.components_.argsort()[:, -3:]\n",
    "top_words = [[tfidf_feature_names[i] for i in topic[::-1]] for topic in top_word_indices]\n",
    "top_words = [' '.join(words) for words in top_words]\n",
    "\n",
    "# Combine topics and top words\n",
    "topics_with_words = list(zip(topic_freqs, top_words))\n",
    "\n",
    "# Sort by frequency\n",
    "sorted_topics = sorted(topics_with_words, key=lambda x: x[0], reverse=False)\n",
    "\n",
    "# Extract sorted freqs and topic names  \n",
    "sorted_freqs, sorted_top_words = zip(*sorted_topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3426702f-a5ea-404d-bb09-5f88068cbb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot sorted topics and freqs  \n",
    "fig, ax = plt.subplots(figsize=(3.5, 6)) \n",
    "\n",
    "# Plot horizontal bars\n",
    "ax.barh(range(n_components), sorted_freqs, color='orange')\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax.set_yticks(range(n_components)) \n",
    "ax.set_yticklabels(sorted_top_words)\n",
    "\n",
    "# Remove x-axis tick labels  \n",
    "ax.set_xticks([]) \n",
    "\n",
    "# Axis labels\n",
    "ax.set_xlabel(\"Frequency\")\n",
    "\n",
    "# Title and tight layout\n",
    "ax.set_title(f\"Fairness\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bed5fb-0401-4f5d-81ff-25b32696df01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Document-topic matrix, where columns are the top words\n",
    "w_df = pd.DataFrame(doc_topic_matrix)\n",
    "w_df.columns = top_words\n",
    "w_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8486209-fbeb-42cb-a307-f5e17e6a7f9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_df_date = w_df.copy()\n",
    "w_df_date['Date'] = top_fairness_media['Date'].values\n",
    "\n",
    "# Convert 'Date' column to datetime and set it as the index\n",
    "w_df_date['Date'] = pd.to_datetime(w_df_date['Date'])\n",
    "w_df_date.set_index('Date', inplace=True)\n",
    "\n",
    "# Resample at weekly level\n",
    "w_df_date = w_df_date.resample('W').sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b07ddea-f548-4c2a-a71f-84a66efeb490",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot\n",
    "fig, ax = plt.subplots(figsize=(12,5))\n",
    "\n",
    "# Plot a stackplot\n",
    "ax.stackplot(w_df_date.index, w_df_date.T, baseline='sym', labels=w_df_date.columns, colors=colors)\n",
    "\n",
    "# Move the legend off of the chart\n",
    "ax.legend(loc=(0,-0.6), ncol=3)\n",
    "\n",
    "plt.title(f'Topic evolution over time in media articles about the fairness dimension (N = {len(w_df)})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cc0ab55-316a-42c5-846c-756e57df0a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each value of the time series by the total weight across all topics\n",
    "w_df_date_norm = w_df_date.copy()\n",
    "w_df_date_norm['weekly_weight_sum'] = w_df_date_norm.apply(np.sum, axis=1)\n",
    "w_df_date_norm.iloc[:, :-1] = w_df_date_norm.iloc[:, :-1].div(w_df_date_norm['weekly_weight_sum'], axis=0)\n",
    "\n",
    "w_df_date_norm.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "688e1ebb-9e43-4b53-afbd-1ad46f1c4c46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"femme homme pourcent\",\n",
    "    \"capitalisation pension fonctionnaire\",\n",
    "    \"culte catholique ministre\",\n",
    "    \"exploitation agricole terre\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='orange', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Fairness dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.25), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2543aec-7b61-4678-a049-0df3e6978d43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"an association don\",\n",
    "    \"per rente plan\",\n",
    "    \"scpi immobilier rendement\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='orange', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Fairness dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab6dce97-2935-4f74-a865-71ad6fd46e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"fonds assurance private\",\n",
    "    \"cotisation emploi social\",\n",
    "    \"pension alimentaire parent\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='orange', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Fairness dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "156a630e-76e8-4838-8345-0e14178038b0",
   "metadata": {},
   "source": [
    "### Risk & Time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83446467-4555-4aa2-9515-ce3d0311d2e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the TF-IDF matrix\n",
    "tfidf_vectorizer = TfidfVectorizer(\n",
    "    max_df=0.95, min_df=5,  \n",
    "    max_features=n_features, \n",
    "    ngram_range=(1, 1))\n",
    "\n",
    "tfidf_risktime = tfidf_vectorizer.fit_transform(corpus_risktime_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb30641-d33f-43a9-8f56-912a42115c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform NMF\n",
    "seed = 42\n",
    "nmf_risktime = NMF(n_components=n_components,\n",
    "                  random_state=seed, # Pass an int for reproducible results across multiple function calls.\n",
    "                  alpha_W=0, # Constant that multiplies the regularization terms of W. Set it to zero (default) to have no regularization on W\n",
    "                  l1_ratio=0, # The regularization mixing parameter 0-1. l1_ratio = 0, penalty is an elementwise L2 penalty (aka Frobenius Norm). For l1_ratio = 1 it is an elementwise L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.\n",
    "                  max_iter=1000,\n",
    "                  init='random'\n",
    "                  ).fit(tfidf_risktime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fb5e99d-6f14-45f0-b19e-687c23d6a2f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the TF-IDF vectorizer\n",
    "tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "plot_top_words(\n",
    "    nmf_risktime, tfidf_feature_names, n_top_words, f\"Topics in articles on risk and time dimension (N = {len(corpus_risktime_media)})\", 'scikit_selmedia_nofiltSW', n_components, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b23d456-c671-4a9f-ac78-5ad8cb17b3e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the vectorizer\n",
    "feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Get the top 3 indices for each topic\n",
    "top_indices = nmf_risktime.components_.argsort()[:,-5:]\n",
    "\n",
    "# Create empty list to store top words  \n",
    "top_words_risktime = []\n",
    "\n",
    "# Loop through each topic\n",
    "for topic in top_indices:\n",
    "  \n",
    "  # Extract top 3 words\n",
    "  top_3 = [feature_names[i] for i in topic]\n",
    "  \n",
    "  # Add to list\n",
    "  top_words_risktime.extend(top_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "641c9939-4c74-4611-a583-7cf5658f6a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the document-topic matrix\n",
    "doc_topic_matrix = nmf_risktime.transform(tfidf_risktime)\n",
    "\n",
    "# Calculate topic frequencies across corpus\n",
    "topic_freqs = doc_topic_matrix.sum(axis=0) \n",
    "\n",
    "# Get top 3 words for each topic\n",
    "top_word_indices = nmf_risktime.components_.argsort()[:, -3:]\n",
    "top_words = [[tfidf_feature_names[i] for i in topic[::-1]] for topic in top_word_indices]\n",
    "top_words = [' '.join(words) for words in top_words]\n",
    "\n",
    "# Combine topics and top words\n",
    "topics_with_words = list(zip(topic_freqs, top_words))\n",
    "\n",
    "# Sort by frequency\n",
    "sorted_topics = sorted(topics_with_words, key=lambda x: x[0], reverse=False)\n",
    "\n",
    "# Extract sorted freqs and topic names  \n",
    "sorted_freqs, sorted_top_words = zip(*sorted_topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf6b6f78-56d9-4767-9698-c116fa053cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot sorted topics and freqs  \n",
    "fig, ax = plt.subplots(figsize=(3.5, 6)) \n",
    "\n",
    "# Plot horizontal bars\n",
    "ax.barh(range(n_components), sorted_freqs, color='seagreen')\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax.set_yticks(range(n_components)) \n",
    "ax.set_yticklabels(sorted_top_words)\n",
    "\n",
    "# Remove x-axis tick labels  \n",
    "ax.set_xticks([]) \n",
    "\n",
    "# Axis labels\n",
    "ax.set_xlabel(\"Frequency\")\n",
    "\n",
    "# Title and tight layout\n",
    "ax.set_title(f\"Risk & Time\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66600ffa-7c7d-4427-80fd-9ccf8a8569fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Document-topic matrix, where columns are the top words\n",
    "w_df = pd.DataFrame(doc_topic_matrix)\n",
    "w_df.columns = top_words\n",
    "w_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b55b5b39-8064-4eaa-84f8-5c15ba8399d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_df_date = w_df.copy()\n",
    "w_df_date['Date'] = top_risktime_media['Date'].values\n",
    "\n",
    "# Convert 'Date' column to datetime and set it as the index\n",
    "w_df_date['Date'] = pd.to_datetime(w_df_date['Date'])\n",
    "w_df_date.set_index('Date', inplace=True)\n",
    "\n",
    "# Resample at weekly level\n",
    "w_df_date = w_df_date.resample('W').sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0824f0d2-ee41-4800-974b-2da90cd735c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot\n",
    "fig, ax = plt.subplots(figsize=(12,5))\n",
    "\n",
    "# Plot a stackplot\n",
    "ax.stackplot(w_df_date.index, w_df_date.T, baseline='sym', labels=w_df_date.columns, colors=colors)\n",
    "\n",
    "# Move the legend off of the chart\n",
    "ax.legend(loc=(0,-0.6), ncol=3)\n",
    "\n",
    "plt.title(f'Topic evolution over time in media articles about the risk & time dimension (N = {len(w_df)})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195e0380-79d5-4f5e-8114-129b42556f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each value of the time series by the total weight across all topics\n",
    "w_df_date_norm = w_df_date.copy()\n",
    "w_df_date_norm['weekly_weight_sum'] = w_df_date_norm.apply(np.sum, axis=1)\n",
    "w_df_date_norm.iloc[:, :-1] = w_df_date_norm.iloc[:, :-1].div(w_df_date_norm['weekly_weight_sum'], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab02c0b-5efa-41c6-afb1-e6bb9428f039",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"assurance fonds vie\",\n",
    "    \"logement immobilier revenu\",\n",
    "    \"femme pension enfant\",\n",
    "    \"capitalisation pension cotisation\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='seagreen', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Risk & Time dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.25), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54799bbf-38d7-4abd-b24c-ac2bfbff6a50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"milliardaire riche fortune\",\n",
    "    \"agricole terre agriculteur\",\n",
    "    \"scpi rendement immobilier\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='seagreen', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Risk & Time dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e55bd8ca-bbb7-44cb-a494-60ce1352d961",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"maladie ministre catholique\",\n",
    "    \"consommation entreprise prix\",\n",
    "    \"per plan revenu\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=2).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='seagreen', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Risk & Time dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "761a90b8-15b1-4114-82aa-746d92a64b0d",
   "metadata": {},
   "source": [
    "### Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f4eb1d-c5f2-48da-a995-661b2d2039bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the TF-IDF matrix\n",
    "tfidf_vectorizer = TfidfVectorizer(\n",
    "    max_df=0.95, min_df=5,  \n",
    "    max_features=n_features, \n",
    "    ngram_range=(1, 1))\n",
    "\n",
    "tfidf_process = tfidf_vectorizer.fit_transform(corpus_process_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0101e2d-c6d0-4b05-9fa7-22d7d2460e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform NMF\n",
    "seed = 42\n",
    "nmf_process = NMF(n_components=n_components,\n",
    "                  random_state=seed, # Pass an int for reproducible results across multiple function calls.\n",
    "                  alpha_W=0, # Constant that multiplies the regularization terms of W. Set it to zero (default) to have no regularization on W\n",
    "                  l1_ratio=0, # The regularization mixing parameter 0-1. l1_ratio = 0, penalty is an elementwise L2 penalty (aka Frobenius Norm). For l1_ratio = 1 it is an elementwise L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.\n",
    "                  max_iter=1000,\n",
    "                  init='random'\n",
    "                  ).fit(tfidf_process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9823622-cda7-4799-bc0c-41e2683e89cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the TF-IDF vectorizer\n",
    "tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "plot_top_words(\n",
    "    nmf_process, tfidf_feature_names, n_top_words, f\"Topics in articles on process dimension (N = {len(corpus_process_media)})\", 'scikit_selmedia_nofiltSW', n_components, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a3fa6a4-9156-4108-b3ac-580f5cef1600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature names from the vectorizer\n",
    "feature_names = tfidf_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Get the top 3 indices for each topic\n",
    "top_indices = nmf_process.components_.argsort()[:,-5:]\n",
    "\n",
    "# Create empty list to store top words  \n",
    "top_words_process = []\n",
    "\n",
    "# Loop through each topic\n",
    "for topic in top_indices:\n",
    "  \n",
    "  # Extract top 3 words\n",
    "  top_3 = [feature_names[i] for i in topic]\n",
    "  \n",
    "  # Add to list\n",
    "  top_words_process.extend(top_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75f1a5ef-45fc-4858-bc91-4634c06f462d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the document-topic matrix\n",
    "doc_topic_matrix = nmf_process.transform(tfidf_process)\n",
    "\n",
    "# Calculate topic frequencies across corpus\n",
    "topic_freqs = doc_topic_matrix.sum(axis=0) \n",
    "\n",
    "# Get top 3 words for each topic\n",
    "top_word_indices = nmf_process.components_.argsort()[:, -3:]\n",
    "top_words = [[tfidf_feature_names[i] for i in topic[::-1]] for topic in top_word_indices]\n",
    "top_words = [' '.join(words) for words in top_words]\n",
    "\n",
    "# Combine topics and top words\n",
    "topics_with_words = list(zip(topic_freqs, top_words))\n",
    "\n",
    "# Sort by frequency\n",
    "sorted_topics = sorted(topics_with_words, key=lambda x: x[0], reverse=False)\n",
    "\n",
    "# Extract sorted freqs and topic names  \n",
    "sorted_freqs, sorted_top_words = zip(*sorted_topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7305207-07f3-4865-934e-7304e9f6c74c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot sorted topics and freqs  \n",
    "fig, ax = plt.subplots(figsize=(3.5, 6)) \n",
    "\n",
    "# Plot horizontal bars\n",
    "ax.barh(range(n_components), sorted_freqs, color='coral')\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax.set_yticks(range(n_components)) \n",
    "ax.set_yticklabels(sorted_top_words)\n",
    "\n",
    "# Remove x-axis tick labels  \n",
    "ax.set_xticks([]) \n",
    "\n",
    "# Axis labels\n",
    "ax.set_xlabel(\"Frequency\")\n",
    "\n",
    "# Title and tight layout\n",
    "ax.set_title(f\"Process\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef00f9ff-4949-4d8a-9cd6-68b3dcd7b464",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Document-topic matrix, where columns are the top words\n",
    "w_df = pd.DataFrame(doc_topic_matrix)\n",
    "w_df.columns = top_words\n",
    "w_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bcdd158-da8c-4d4a-b36f-69a3baa290a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_df_date = w_df.copy()\n",
    "w_df_date['Date'] = top_process_media['Date'].values\n",
    "\n",
    "# Convert 'Date' column to datetime and set it as the index\n",
    "w_df_date['Date'] = pd.to_datetime(w_df_date['Date'])\n",
    "w_df_date.set_index('Date', inplace=True)\n",
    "\n",
    "# Resample at weekly level\n",
    "w_df_date = w_df_date.resample('W').sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "566526de-3f04-49b1-af13-b5c488454bf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot\n",
    "fig, ax = plt.subplots(figsize=(12,5))\n",
    "\n",
    "# Plot a stackplot\n",
    "ax.stackplot(w_df_date.index, w_df_date.T, baseline='sym', labels=w_df_date.columns, colors=colors)\n",
    "\n",
    "# Move the legend off of the chart\n",
    "ax.legend(loc=(0,-0.6), ncol=3)\n",
    "\n",
    "plt.title(f'Topic evolution over time in media articles about the process dimension (N = {len(w_df)})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9b236c-741c-462a-9281-f84cc01bc12e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide each value of the time series by the total weight across all topics\n",
    "w_df_date_norm = w_df_date.copy()\n",
    "w_df_date_norm['weekly_weight_sum'] = w_df_date_norm.apply(np.sum, axis=1)\n",
    "w_df_date_norm.iloc[:, :-1] = w_df_date_norm.iloc[:, :-1].div(w_df_date_norm['weekly_weight_sum'], axis=0)\n",
    "\n",
    "w_df_date_norm.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df80cfd4-85fd-412b-ad5b-fc332f077d97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"motion censure 49\",\n",
    "    \"cgt syndicat cfdt\",\n",
    "    \"conseil constitutionnel rip\",\n",
    "    \"ministre borne macron\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=3).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='coral', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Process dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.25), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2337ae55-8ee7-420f-aaf2-481056f7cbcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"liot proposition groupe\",\n",
    "    \"nupes gauche communiste\",\n",
    "    \"lr immigration texte\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=3).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='coral', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Process dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d3b95d1-e59e-46dc-b7e2-0a49ec969ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected topics\n",
    "selected_columns = [\n",
    "    \"vote texte gouvernement\",\n",
    "    \"politique social peuple\",\n",
    "    \"amendement lfi insoumis\",\n",
    "]\n",
    "\n",
    "# Define line styles\n",
    "line_styles = ['-', '--', '-.', ':']\n",
    "\n",
    "# Create a figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "# Loop through each selected column, apply rolling average to normalized data, and plot the smoothed time series\n",
    "for i, column in enumerate(selected_columns):\n",
    "    smoothed_series = w_df_date_norm[column].rolling(window=3).mean()\n",
    "    ax.plot(w_df_date.index, smoothed_series, label=f'{column}', color='coral', linestyle=line_styles[i])\n",
    "\n",
    "# Set the labels and legend\n",
    "ax.set_xlabel('Date')\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('Normalised topic weight')\n",
    "#ax.set_title(f'Process dimension in media: Evolution of selected topics')\n",
    "ax.legend(loc=(0,-0.3), ncol=4)\n",
    "\n",
    "plt.xticks(rotation=0)\n",
    "\n",
    "# Remove top and right borders\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Add vertical lines and annotations\n",
    "for date, description in important_dates.items():\n",
    "    plt.axvline(x=date, linewidth=.3, alpha=0.5)  # Add vertical line\n",
    "    plt.text(date, plt.ylim()[1] - 0.003, description, rotation=0, va='top', ha='right', fontsize=9, color='grey')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc5c480e-95b0-47f8-811c-eba398d243fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_df_art = w_df.copy()\n",
    "w_df_art['document'] = top_process_media['article'].to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b66fcab-616b-4ede-8c22-2f2b23b0bd7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace 'path_to_folder' with the path where you want to create the folder\n",
    "output_folder = '/top_docs/process/'\n",
    "os.makedirs(output_folder, exist_ok=True)\n",
    "\n",
    "topic_columns = w_df_art.columns[:10]\n",
    "\n",
    "for i, topic_col in enumerate(topic_columns):\n",
    "    # Sort documents based on the topic column\n",
    "    sorted_docs = w_df_art.sort_values(by=topic_col, ascending=False)\n",
    "\n",
    "    # Extract top 10 documents\n",
    "    top_docs = sorted_docs.head(10)\n",
    "\n",
    "    # Create file content\n",
    "    file_content = f\"Topic #{i}: {topic_col}\\n\\n\"\n",
    "    for index, row in top_docs.iterrows():\n",
    "        file_content += f\"Document: {row['document']}\\n\\n- - - - - - - - - -\\n\\n\"\n",
    "\n",
    "    # Write to file\n",
    "    file_path = os.path.join(output_folder, f'top_doc_topic{i}_{topic_col}.txt')\n",
    "    with open(file_path, 'w', encoding='utf-8') as file:\n",
    "        file.write(file_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2d1353a-b365-4d4d-b42b-6770735afbbe",
   "metadata": {},
   "source": [
    "# Coupled matrix factorisation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b76cde-c283-41ec-9bbf-dab5d30a76bb",
   "metadata": {},
   "source": [
    "## Media vs speeches"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b91890cc-9eaa-4129-ad21-517ce4a8b56b",
   "metadata": {},
   "source": [
    "### General"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc37674-55a7-44d5-a749-0c6db3471abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the processed articles from both parliamentary and media sources into a single list\n",
    "documents = sim_df_parl['article_processed'].to_list() + sim_df_media['article_processed'].to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03827f0a-8159-493a-9abb-3f4089f300e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "\n",
    "# Create TF matrix\n",
    "count_vectorizer = CountVectorizer(max_df=0.95, min_df=5, \n",
    "                                   max_features=n_features, \n",
    "                                   ngram_range=(1, 1))\n",
    "\n",
    "count_matrix = count_vectorizer.fit_transform(documents)\n",
    "\n",
    "# Split the count_matrix into two matrices for each corpus\n",
    "count_matrix_parl = count_matrix[:len(sim_df_parl), :]\n",
    "count_matrix_media = count_matrix[-len(sim_df_media):, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37a28cf1-64c0-4535-b5d8-cc0c3882174d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF values for each corpus\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "tfidf_matrix_parl = tfidf_transformer.fit_transform(count_matrix_parl,)\n",
    "tfidf_matrix_media = tfidf_transformer.fit_transform(count_matrix_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1a2b28-0f21-4496-8601-b9abc3f40e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare matrices for CMF\n",
    "matrices = [np.array(tfidf_matrix_parl.todense()), np.array(tfidf_matrix_media.todense())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f4c2ce8-1ff1-4b29-9a6b-e4c788f5ae0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define parameters for Coupled Matrix Factorization (CMF)\n",
    "n_components = 10\n",
    "n_top_words = 10\n",
    "penalty_strength = 0.08\n",
    "\n",
    "# Perform Coupled Matrix Factorization (CMF) using the specified parameters\n",
    "cmf = decomposition.cmf_aoadmm(matrices, rank=n_components, \n",
    "                               non_negative=True,\n",
    "                               l2_penalty=[0,penalty_strength,0],\n",
    "                               l1_penalty=[0,0,penalty_strength],\n",
    "                               init='random',\n",
    "                               n_iter_max=1000,\n",
    "                              )\n",
    "\n",
    "# Extract the factorization results: weights and factors A, B_is, and C\n",
    "weights, (A, B_is, C) = cmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08d567b7-6756-40c2-be6a-93178d7ec317",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature names (words) from the TF-IDF vectorizer\n",
    "feature_names = count_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Initialize an empty list to store top words for each topic\n",
    "top_words_list = []\n",
    "top_3_words = []\n",
    "\n",
    "# Print the top words for each topic\n",
    "for topic_idx, topic in enumerate(C.T):\n",
    "    top_words_idx = topic.argsort()[:-10 - 1:-1]\n",
    "    top_words = [feature_names[i] for i in top_words_idx]\n",
    "    \n",
    "    # Append the list of top words for the current topic to the main list\n",
    "    top_words_list.append(top_words)\n",
    "\n",
    "    # Concatenate the top three words into a single string\n",
    "    top_words_string = ' '.join(top_words[:3])\n",
    "    top_3_words.append(top_words_string)\n",
    "    \n",
    "    print(f\"Topic #{topic_idx + 1}: {', '.join(top_words)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b15df726-0f39-4fac-af76-776cc6b8c5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define function to compute L2 norm of a matrix\n",
    "def l2_norm(matrix):\n",
    "    \"\"\"Computes the L2 norm of a matrix.\"\"\"\n",
    "    return np.sqrt(np.sum(np.square(matrix)))\n",
    "\n",
    "# Define fuction to compute L2 in specific situation\n",
    "def norm_topics_in_corpus(i, n_components, A, B_is, C):\n",
    "    \"Computes the L2 norm of the product of the component matrices which describe one topic in one corpus for all the components\"\n",
    "    \n",
    "    # i is the index of the corpus (hence, 0 or 1 in case we perform coupled matrix factorisation with two matrices)\n",
    "    # n_components is the number of components - in topic modeling topics - that we extract from the corpora\n",
    "    \n",
    "    topic_weights = []\n",
    "\n",
    "    # Computes the L2 norm of the product of all the matrices\n",
    "    Ai = A[i]\n",
    "    Bi = B_is[i]\n",
    "    Ai_rs = Ai.reshape(-1, 1)\n",
    "    Bi_rs = Bi.reshape(-1, 1)\n",
    "    corpus_i = np.dot(np.dot(Ai_rs, Bi_rs.T).T, C.T)\n",
    "    corpus_norm_i = l2_norm(corpus_i)\n",
    "    \n",
    "    # Computes the weight of each topic in corpus as norm of the product of the component matrices divided by corpus-wide norm\n",
    "    for t in range(0, n_components):\n",
    "        Ait = A[i][t]\n",
    "        Bit = B_is[i][:, t]\n",
    "        Ct = C[:,t]\n",
    "    \n",
    "        Ait_rs = Ait.reshape(-1, 1)\n",
    "        Bit_rs = Bit.reshape(-1, 1)\n",
    "        Ct_rs = Ct.reshape(-1, 1)\n",
    "    \n",
    "        topic_component_it = np.dot(np.dot(Ait_rs, Bit_rs.T).T, Ct_rs.T)\n",
    "        topic_norm_it = l2_norm(topic_component_it)\n",
    "        topic_weight_it = topic_norm_it/corpus_norm_i\n",
    "        \n",
    "        topic_weights.append(topic_weight_it)\n",
    "        \n",
    "    return topic_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d0b2de-3150-4dd8-bfdb-05b111d4004f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate the normalized topic weights for two corpora (index 0 and 1) using Coupled Matrix Factorization (CMF)\n",
    "topic_weights_0 = norm_topics_in_corpus(0, n_components, A, B_is, C)\n",
    "topic_weights_1 = norm_topics_in_corpus(1, n_components, A, B_is, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c81f030-f117-4aae-a58a-bb062ec2557a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_parl = pd.DataFrame(data=[topic_weights_0], index=['Parliament Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_parl = df_topics_parl.iloc[:, df_topics_parl.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_parl.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Parliament Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_parl.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b995d5b-f2a3-4343-818d-2e5968910eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_media = pd.DataFrame(data=[topic_weights_1], index=['Media Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_media = df_topics_media.iloc[:, df_topics_media.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_media.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Media Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_media.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "055041c4-f972-4183-9cf2-170f5cb9c4f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots\n",
    "fig, axs = plt.subplots(1, 2, figsize=(6, 8), sharey=True)\n",
    "\n",
    "# Plot Parliament Corpus Topics\n",
    "axs[0].barh(range(1, n_components + 1), df_topics_parl.values[0][::-1], color='dodgerblue')\n",
    "axs[0].set_ylabel('Topic')\n",
    "axs[0].set_xlabel('Weight')\n",
    "axs[0].set_title('Parliament')\n",
    "axs[0].set_yticks(range(1, n_components + 1))\n",
    "axs[0].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "# Plot Media Corpus Topics\n",
    "axs[1].barh(range(1, n_components + 1), df_topics_media.values[0][::-1], color='dodgerblue')\n",
    "axs[1].set_xlabel('Weight')\n",
    "axs[1].set_title('Media')\n",
    "axs[1].set_yticks(range(1, n_components + 1))\n",
    "axs[1].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "fig.suptitle('Economy', fontsize=16)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd0f4434-a8b8-475a-bff6-f82db3cf0764",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the topic weight DataFrames for Parliament and Media corpora into one for comparison\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(4, 6)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage Mismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "plt.title(f'Media and Parliament comparison', loc='right')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95439631-2944-43c2-876c-bd5fe84749dd",
   "metadata": {},
   "source": [
    "### Economy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acc75933-b12d-4ef8-850d-dabcb54fcac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the processed documents from Parliament and Media corpora into one list of documents\n",
    "documents = corpus_economy_parl.to_list() + corpus_economy_media.to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c375d4db-9b08-4f74-988a-ea625734f18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "\n",
    "# Create TF matrix\n",
    "count_vectorizer = CountVectorizer(max_df=0.95, min_df=5, \n",
    "                                   max_features=n_features, \n",
    "                                   ngram_range=(1, 1))\n",
    "\n",
    "count_matrix = count_vectorizer.fit_transform(documents)\n",
    "\n",
    "# Split the count_matrix into two matrices for each corpus\n",
    "count_matrix_parl = count_matrix[:len(corpus_economy_parl), :]\n",
    "count_matrix_media = count_matrix[-len(corpus_economy_media):, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4f2242-2f25-40a8-80eb-b2781ec38d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF values for each corpus\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "tfidf_matrix_parl = tfidf_transformer.fit_transform(count_matrix_parl,)\n",
    "tfidf_matrix_media = tfidf_transformer.fit_transform(count_matrix_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27d2077-55f2-4eb1-915d-b73ea85400c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare matrices for CMF\n",
    "matrices = [np.array(tfidf_matrix_parl.todense()), np.array(tfidf_matrix_media.todense())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf5001d-d4b9-4899-8660-1371a69f728e",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_components = 10\n",
    "n_top_words = 10\n",
    "penalty_strength = 0.08\n",
    "\n",
    "cmf = decomposition.cmf_aoadmm(matrices, rank=n_components, \n",
    "                               non_negative=True,\n",
    "                               l2_penalty=[0,penalty_strength,0],\n",
    "                               l1_penalty=[0,0,penalty_strength],\n",
    "                               init='random',\n",
    "                               n_iter_max=1000,\n",
    "                              )\n",
    "\n",
    "weights, (A, B_is, C) = cmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a551260-b16d-47a9-bcaa-a6f47fc258f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature names (words) from the TF-IDF vectorizer\n",
    "feature_names = count_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Initialize an empty list to store top words for each topic\n",
    "top_words_list = []\n",
    "top_3_words = []\n",
    "\n",
    "# Print the top words for each topic\n",
    "for topic_idx, topic in enumerate(C.T):\n",
    "    top_words_idx = topic.argsort()[:-10 - 1:-1]\n",
    "    top_words = [feature_names[i] for i in top_words_idx]\n",
    "    \n",
    "    # Append the list of top words for the current topic to the main list\n",
    "    top_words_list.append(top_words)\n",
    "\n",
    "    # Concatenate the top three words into a single string\n",
    "    top_words_string = ' '.join(top_words[:3])\n",
    "    top_3_words.append(top_words_string)\n",
    "    \n",
    "    print(f\"Topic #{topic_idx + 1}: {', '.join(top_words)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6028e876-91d6-4c51-abf0-fb2f28d987a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_norm(matrix):\n",
    "    \"\"\"Computes the L2 norm of a matrix.\"\"\"\n",
    "    return np.sqrt(np.sum(np.square(matrix)))\n",
    "\n",
    "def norm_topics_in_corpus(i, n_components, A, B_is, C):\n",
    "    \"Computes the L2 norm of the product of the component matrices which describe one topic in one corpus for all the components\"\n",
    "    \n",
    "    # i is the index of the corpus (hence, 0 or 1 in case we perform coupled matrix factorisation with two matrices)\n",
    "    # n_components is the number of components - in topic modeling topics - that we extract from the corpora\n",
    "    \n",
    "    topic_weights = []\n",
    "\n",
    "    # Computes the L2 norm of the product of all the matrices\n",
    "    Ai = A[i]\n",
    "    Bi = B_is[i]\n",
    "    Ai_rs = Ai.reshape(-1, 1)\n",
    "    Bi_rs = Bi.reshape(-1, 1)\n",
    "    corpus_i = np.dot(np.dot(Ai_rs, Bi_rs.T).T, C.T)\n",
    "    corpus_norm_i = l2_norm(corpus_i)\n",
    "    \n",
    "    # Computes the weight of each topic in corpus as norm of the product of the component matrices divided by corpus-wide norm\n",
    "    for t in range(0, n_components):\n",
    "        Ait = A[i][t]\n",
    "        Bit = B_is[i][:, t]\n",
    "        Ct = C[:,t]\n",
    "    \n",
    "        Ait_rs = Ait.reshape(-1, 1)\n",
    "        Bit_rs = Bit.reshape(-1, 1)\n",
    "        Ct_rs = Ct.reshape(-1, 1)\n",
    "    \n",
    "        topic_component_it = np.dot(np.dot(Ait_rs, Bit_rs.T).T, Ct_rs.T)\n",
    "        topic_norm_it = l2_norm(topic_component_it)\n",
    "        topic_weight_it = topic_norm_it/corpus_norm_i\n",
    "        \n",
    "        topic_weights.append(topic_weight_it)\n",
    "        \n",
    "    return topic_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de6db2a4-6fa4-4f4a-a202-7af49a9e020e",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_weights_0 = norm_topics_in_corpus(0, n_components, A, B_is, C)\n",
    "topic_weights_1 = norm_topics_in_corpus(1, n_components, A, B_is, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b48f3030-fa2f-498a-8f49-0a37ac187cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_parl = pd.DataFrame(data=[topic_weights_0], index=['Parliament Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_parl = df_topics_parl.iloc[:, df_topics_parl.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_parl.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Parliament Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_parl.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfcf12c8-25bd-4cce-ab82-a9ac5586427c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_media = pd.DataFrame(data=[topic_weights_1], index=['Media Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_media = df_topics_media.iloc[:, df_topics_media.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_media.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Media Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_media.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22cbd445-55cb-4d88-9a6a-f2f4b0264756",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots\n",
    "fig, axs = plt.subplots(1, 2, figsize=(6, 8), sharey=True)\n",
    "\n",
    "# Plot Parliament Corpus Topics\n",
    "axs[0].barh(range(1, n_components + 1), df_topics_parl.values[0][::-1], color='dodgerblue')\n",
    "axs[0].set_ylabel('Topic')\n",
    "axs[0].set_xlabel('Weight')\n",
    "axs[0].set_title('Parliament')\n",
    "axs[0].set_yticks(range(1, n_components + 1))\n",
    "axs[0].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "# Plot Media Corpus Topics\n",
    "axs[1].barh(range(1, n_components + 1), df_topics_media.values[0][::-1], color='dodgerblue')\n",
    "axs[1].set_xlabel('Weight')\n",
    "axs[1].set_title('Media')\n",
    "axs[1].set_yticks(range(1, n_components + 1))\n",
    "axs[1].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "fig.suptitle('Economy', fontsize=16)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec405ecd-6971-4b87-9dc7-17b1c812ee34",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "plt.title(f'Economy', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d38910b8-5db5-4bcf-a06f-9d6609ebfb1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "#plt.title(f'Economy', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bc73108-7a6b-43a4-a2ea-5ce1e097d1d8",
   "metadata": {},
   "source": [
    "### Fairness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc659ad6-367e-41d9-8846-c0cbba92bea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the processed documents from Parliament and Media corpora into one list of documents\n",
    "documents = corpus_fairness_parl.to_list() + corpus_fairness_media.to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c8a558e-d882-431c-ad75-f9048fcabdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "\n",
    "# Create TF matrix\n",
    "count_vectorizer = CountVectorizer(max_df=0.95, min_df=5, \n",
    "                                   max_features=n_features, \n",
    "                                   ngram_range=(1, 1))\n",
    "\n",
    "count_matrix = count_vectorizer.fit_transform(documents)\n",
    "\n",
    "# Split the count_matrix into two matrices for each corpus\n",
    "count_matrix_parl = count_matrix[:len(corpus_fairness_parl), :]\n",
    "count_matrix_media = count_matrix[-len(corpus_fairness_media):, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82d95a0d-32b2-4387-9f09-4e48aceb51ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF values for each corpus\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "tfidf_matrix_parl = tfidf_transformer.fit_transform(count_matrix_parl,)\n",
    "tfidf_matrix_media = tfidf_transformer.fit_transform(count_matrix_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c60f5a-ff7e-4f51-8675-084bd83c6cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare matrices for CMF\n",
    "matrices = [np.array(tfidf_matrix_parl.todense()), np.array(tfidf_matrix_media.todense())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f16c8095-2e85-4b14-92bd-1d460dbba6d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CMF parameters and implementation\n",
    "n_components = 10\n",
    "n_top_words = 10\n",
    "penalty_strength = 0.08\n",
    "\n",
    "cmf = decomposition.cmf_aoadmm(matrices, rank=n_components, \n",
    "                               non_negative=True,\n",
    "                               l2_penalty=[0,penalty_strength,0],\n",
    "                               l1_penalty=[0,0,penalty_strength],\n",
    "                               init='random',\n",
    "                               n_iter_max=1000,\n",
    "                              )\n",
    "\n",
    "weights, (A, B_is, C) = cmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013d8e8c-fbab-427c-8992-b8c09be194d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature names (words) from the TF-IDF vectorizer\n",
    "feature_names = count_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Initialize an empty list to store top words for each topic\n",
    "top_words_list = []\n",
    "top_3_words = []\n",
    "\n",
    "# Print the top words for each topic\n",
    "for topic_idx, topic in enumerate(C.T):\n",
    "    top_words_idx = topic.argsort()[:-10 - 1:-1]\n",
    "    top_words = [feature_names[i] for i in top_words_idx]\n",
    "    \n",
    "    # Append the list of top words for the current topic to the main list\n",
    "    top_words_list.append(top_words)\n",
    "\n",
    "    # Concatenate the top three words into a single string\n",
    "    top_words_string = ' '.join(top_words[:3])\n",
    "    top_3_words.append(top_words_string)\n",
    "    \n",
    "    print(f\"Topic #{topic_idx + 1}: {', '.join(top_words)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc08699-4633-4022-b0f7-b52d68e4cedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_norm(matrix):\n",
    "    \"\"\"Computes the L2 norm of a matrix.\"\"\"\n",
    "    return np.sqrt(np.sum(np.square(matrix)))\n",
    "\n",
    "def norm_topics_in_corpus(i, n_components, A, B_is, C):\n",
    "    \"Computes the L2 norm of the product of the component matrices which describe one topic in one corpus for all the components\"\n",
    "    \n",
    "    # i is the index of the corpus (hence, 0 or 1 in case we perform coupled matrix factorisation with two matrices)\n",
    "    # n_components is the number of components - in topic modeling topics - that we extract from the corpora\n",
    "    \n",
    "    topic_weights = []\n",
    "\n",
    "    # Computes the L2 norm of the product of all the matrices\n",
    "    Ai = A[i]\n",
    "    Bi = B_is[i]\n",
    "    Ai_rs = Ai.reshape(-1, 1)\n",
    "    Bi_rs = Bi.reshape(-1, 1)\n",
    "    corpus_i = np.dot(np.dot(Ai_rs, Bi_rs.T).T, C.T)\n",
    "    corpus_norm_i = l2_norm(corpus_i)\n",
    "    \n",
    "    # Computes the weight of each topic in corpus as norm of the product of the component matrices divided by corpus-wide norm\n",
    "    for t in range(0, n_components):\n",
    "        Ait = A[i][t]\n",
    "        Bit = B_is[i][:, t]\n",
    "        Ct = C[:,t]\n",
    "    \n",
    "        Ait_rs = Ait.reshape(-1, 1)\n",
    "        Bit_rs = Bit.reshape(-1, 1)\n",
    "        Ct_rs = Ct.reshape(-1, 1)\n",
    "    \n",
    "        topic_component_it = np.dot(np.dot(Ait_rs, Bit_rs.T).T, Ct_rs.T)\n",
    "        topic_norm_it = l2_norm(topic_component_it)\n",
    "        topic_weight_it = topic_norm_it/corpus_norm_i\n",
    "        \n",
    "        topic_weights.append(topic_weight_it)\n",
    "        \n",
    "    return topic_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b4a86c-0e22-46d2-9226-b01f6f8a4e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Implement norm\n",
    "topic_weights_0 = norm_topics_in_corpus(0, n_components, A, B_is, C)\n",
    "topic_weights_1 = norm_topics_in_corpus(1, n_components, A, B_is, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c78d666-3a9a-4e07-b8ff-4f376d2ab063",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_parl = pd.DataFrame(data=[topic_weights_0], index=['Parliament Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_parl = df_topics_parl.iloc[:, df_topics_parl.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_parl.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Parliament Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_parl.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a9208f3-bf78-4533-bb5c-b9aa8cce9c82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_media = pd.DataFrame(data=[topic_weights_1], index=['Media Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_media = df_topics_media.iloc[:, df_topics_media.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_media.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Media Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_media.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8574cc7-c0cc-4204-ac18-4af6b29db409",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots\n",
    "fig, axs = plt.subplots(1, 2, figsize=(6, 8), sharey=True)\n",
    "\n",
    "# Plot Parliament Corpus Topics\n",
    "axs[0].barh(range(1, n_components + 1), df_topics_parl.values[0][::-1], color='dodgerblue')\n",
    "axs[0].set_ylabel('Topic')\n",
    "axs[0].set_xlabel('Weight')\n",
    "axs[0].set_title('Parliament')\n",
    "axs[0].set_yticks(range(1, n_components + 1))\n",
    "axs[0].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "# Plot Media Corpus Topics\n",
    "axs[1].barh(range(1, n_components + 1), df_topics_media.values[0][::-1], color='dodgerblue')\n",
    "axs[1].set_xlabel('Weight')\n",
    "axs[1].set_title('Media')\n",
    "axs[1].set_yticks(range(1, n_components + 1))\n",
    "axs[1].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "fig.suptitle('Fairness', fontsize=16)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d30d83-5656-4050-8fa4-9bd4ae96e64d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate media and parliament DataFrames\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "plt.title(f'Fairness', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16ab6fa4-00ed-41cf-94ed-469244623b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "#plt.title(f'Fairness', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51373d03-d608-48bb-a571-04565317c510",
   "metadata": {},
   "source": [
    "### Risk & Time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a58154a1-3bc9-438c-a06e-44a414ce04f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the processed documents from Parliament and Media corpora into one list of documents\n",
    "documents = corpus_risktime_parl.to_list() + corpus_risktime_media.to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab80331-b182-4172-b0aa-3a4fefafe321",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "\n",
    "# Create TF matrix\n",
    "count_vectorizer = CountVectorizer(max_df=0.95, min_df=5, \n",
    "                                   max_features=n_features, \n",
    "                                   ngram_range=(1, 1))\n",
    "\n",
    "count_matrix = count_vectorizer.fit_transform(documents)\n",
    "\n",
    "# Split the count_matrix into two matrices for each corpus\n",
    "count_matrix_parl = count_matrix[:len(corpus_risktime_parl), :]\n",
    "count_matrix_media = count_matrix[-len(corpus_risktime_media):, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec98248-e859-41e8-90f0-b88e4e2cee49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF values for each corpus\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "tfidf_matrix_parl = tfidf_transformer.fit_transform(count_matrix_parl,)\n",
    "tfidf_matrix_media = tfidf_transformer.fit_transform(count_matrix_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce7ecd8-2b26-4c3a-bbc5-9fe815ed88ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare matrices for CMF\n",
    "matrices = [np.array(tfidf_matrix_parl.todense()), np.array(tfidf_matrix_media.todense())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4910632-742b-46ac-9ddf-313fa48db067",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CMF parameters and implementation\n",
    "n_components = 10\n",
    "n_top_words = 10\n",
    "penalty_strength = 0.08\n",
    "\n",
    "cmf = decomposition.cmf_aoadmm(matrices, rank=n_components, \n",
    "                               non_negative=True,\n",
    "                               l2_penalty=[0,penalty_strength,0],\n",
    "                               l1_penalty=[0,0,penalty_strength],\n",
    "                               init='random',\n",
    "                               n_iter_max=1000,\n",
    "                              )\n",
    "\n",
    "weights, (A, B_is, C) = cmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11581d6c-8a9b-42a5-894a-9c0fc9b85509",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature names (words) from the TF-IDF vectorizer\n",
    "feature_names = count_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Initialize an empty list to store top words for each topic\n",
    "top_words_list = []\n",
    "top_3_words = []\n",
    "\n",
    "# Print the top words for each topic\n",
    "for topic_idx, topic in enumerate(C.T):\n",
    "    top_words_idx = topic.argsort()[:-10 - 1:-1]\n",
    "    top_words = [feature_names[i] for i in top_words_idx]\n",
    "    \n",
    "    # Append the list of top words for the current topic to the main list\n",
    "    top_words_list.append(top_words)\n",
    "\n",
    "    # Concatenate the top three words into a single string\n",
    "    top_words_string = ' '.join(top_words[:3])\n",
    "    top_3_words.append(top_words_string)\n",
    "    \n",
    "    print(f\"Topic #{topic_idx + 1}: {', '.join(top_words)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f38f19e7-b69b-4143-b519-d1ec9b6c2c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_norm(matrix):\n",
    "    \"\"\"Computes the L2 norm of a matrix.\"\"\"\n",
    "    return np.sqrt(np.sum(np.square(matrix)))\n",
    "\n",
    "def norm_topics_in_corpus(i, n_components, A, B_is, C):\n",
    "    \"Computes the L2 norm of the product of the component matrices which describe one topic in one corpus for all the components\"\n",
    "    \n",
    "    # i is the index of the corpus (hence, 0 or 1 in case we perform coupled matrix factorisation with two matrices)\n",
    "    # n_components is the number of components - in topic modeling topics - that we extract from the corpora\n",
    "    \n",
    "    topic_weights = []\n",
    "\n",
    "    # Computes the L2 norm of the product of all the matrices\n",
    "    Ai = A[i]\n",
    "    Bi = B_is[i]\n",
    "    Ai_rs = Ai.reshape(-1, 1)\n",
    "    Bi_rs = Bi.reshape(-1, 1)\n",
    "    corpus_i = np.dot(np.dot(Ai_rs, Bi_rs.T).T, C.T)\n",
    "    corpus_norm_i = l2_norm(corpus_i)\n",
    "    \n",
    "    # Computes the weight of each topic in corpus as norm of the product of the component matrices divided by corpus-wide norm\n",
    "    for t in range(0, n_components):\n",
    "        Ait = A[i][t]\n",
    "        Bit = B_is[i][:, t]\n",
    "        Ct = C[:,t]\n",
    "    \n",
    "        Ait_rs = Ait.reshape(-1, 1)\n",
    "        Bit_rs = Bit.reshape(-1, 1)\n",
    "        Ct_rs = Ct.reshape(-1, 1)\n",
    "    \n",
    "        topic_component_it = np.dot(np.dot(Ait_rs, Bit_rs.T).T, Ct_rs.T)\n",
    "        topic_norm_it = l2_norm(topic_component_it)\n",
    "        topic_weight_it = topic_norm_it/corpus_norm_i\n",
    "        \n",
    "        topic_weights.append(topic_weight_it)\n",
    "        \n",
    "    return topic_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fe5960b-00c1-4a9c-9da7-24423f5de2cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_weights_0 = norm_topics_in_corpus(0, n_components, A, B_is, C)\n",
    "topic_weights_1 = norm_topics_in_corpus(1, n_components, A, B_is, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee5b0702-005e-471a-8542-9b5652d07abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_parl = pd.DataFrame(data=[topic_weights_0], index=['Parliament Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_parl = df_topics_parl.iloc[:, df_topics_parl.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_parl.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Parliament Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_parl.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a63931e7-2998-4555-bcf4-58f5799cce01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_media = pd.DataFrame(data=[topic_weights_1], index=['Media Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_media = df_topics_media.iloc[:, df_topics_media.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_media.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Media Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_media.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dbfd6f7-bcdf-4c1b-b12e-d7a4ca6444ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots\n",
    "fig, axs = plt.subplots(1, 2, figsize=(6, 8), sharey=True)\n",
    "\n",
    "# Plot Parliament Corpus Topics\n",
    "axs[0].barh(range(1, n_components + 1), df_topics_parl.values[0][::-1], color='dodgerblue')\n",
    "axs[0].set_ylabel('Topic')\n",
    "axs[0].set_xlabel('Weight')\n",
    "axs[0].set_title('Parliament')\n",
    "axs[0].set_yticks(range(1, n_components + 1))\n",
    "axs[0].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "# Plot Media Corpus Topics\n",
    "axs[1].barh(range(1, n_components + 1), df_topics_media.values[0][::-1], color='dodgerblue')\n",
    "axs[1].set_xlabel('Weight')\n",
    "axs[1].set_title('Media')\n",
    "axs[1].set_yticks(range(1, n_components + 1))\n",
    "axs[1].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "fig.suptitle('Risk & Time', fontsize=16)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4104fce7-b3b7-432d-9e76-b53bc428761f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate parliament and media DataFrames\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "plt.title(f'Risk & Time', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e09b44b8-0df5-4021-b41f-36aa0960f2b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate parliament and media DataFrames\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "#plt.title(f'Risk & Time', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ac2e6a8-89e6-47ee-9783-3c3d0b2009d3",
   "metadata": {},
   "source": [
    "### Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90d3575-55d5-490b-bda0-b39d84a84b7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the processed documents from Parliament and Media corpora into one list of documents\n",
    "documents = corpus_process_parl.to_list() + corpus_process_media.to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eece4896-c28a-4662-832a-f6d697236069",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize variables which will be used later\n",
    "n_features = 5000\n",
    "\n",
    "# Create TF matrix\n",
    "count_vectorizer = CountVectorizer(max_df=0.95, min_df=5, \n",
    "                                   max_features=n_features, \n",
    "                                   ngram_range=(1, 1))\n",
    "\n",
    "count_matrix = count_vectorizer.fit_transform(documents)\n",
    "\n",
    "# Split the count_matrix into two matrices for each corpus\n",
    "count_matrix_parl = count_matrix[:len(corpus_process_parl), :]\n",
    "count_matrix_media = count_matrix[-len(corpus_process_media):, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afa542ea-9e1d-461d-9f34-fe1ecf208021",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate TF-IDF values for each corpus\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "tfidf_matrix_parl = tfidf_transformer.fit_transform(count_matrix_parl,)\n",
    "tfidf_matrix_media = tfidf_transformer.fit_transform(count_matrix_media)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b60bb04-6c6f-4521-8b1a-a9e028fb64e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare matrices for CMF\n",
    "matrices = [np.array(tfidf_matrix_parl.todense()), np.array(tfidf_matrix_media.todense())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09f414a6-df56-49be-8f2f-12df9cc4700d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CMF parameters and implementation\n",
    "n_components = 10\n",
    "n_top_words = 10\n",
    "penalty_strength = 0.08\n",
    "\n",
    "cmf = decomposition.cmf_aoadmm(matrices, rank=n_components, \n",
    "                               non_negative=True,\n",
    "                               l2_penalty=[0,penalty_strength,0],\n",
    "                               l1_penalty=[0,0,penalty_strength],\n",
    "                               init='random',\n",
    "                               n_iter_max=1000,\n",
    "                              )\n",
    "\n",
    "weights, (A, B_is, C) = cmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e055e962-6f56-4375-84c6-2c56a6f213b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature names (words) from the TF-IDF vectorizer\n",
    "feature_names = count_vectorizer.get_feature_names_out()\n",
    "\n",
    "# Initialize an empty list to store top words for each topic\n",
    "top_words_list = []\n",
    "top_3_words = []\n",
    "\n",
    "# Print the top words for each topic\n",
    "for topic_idx, topic in enumerate(C.T):\n",
    "    top_words_idx = topic.argsort()[:-10 - 1:-1]\n",
    "    top_words = [feature_names[i] for i in top_words_idx]\n",
    "    \n",
    "    # Append the list of top words for the current topic to the main list\n",
    "    top_words_list.append(top_words)\n",
    "\n",
    "    # Concatenate the top three words into a single string\n",
    "    top_words_string = ' '.join(top_words[:3])\n",
    "    top_3_words.append(top_words_string)\n",
    "    \n",
    "    print(f\"Topic #{topic_idx + 1}: {', '.join(top_words)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63b3718-24fc-4e41-a337-d7cafbf94fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_norm(matrix):\n",
    "    \"\"\"Computes the L2 norm of a matrix.\"\"\"\n",
    "    return np.sqrt(np.sum(np.square(matrix)))\n",
    "\n",
    "def norm_topics_in_corpus(i, n_components, A, B_is, C):\n",
    "    \"Computes the L2 norm of the product of the component matrices which describe one topic in one corpus for all the components\"\n",
    "    \n",
    "    # i is the index of the corpus (hence, 0 or 1 in case we perform coupled matrix factorisation with two matrices)\n",
    "    # n_components is the number of components - in topic modeling topics - that we extract from the corpora\n",
    "    \n",
    "    topic_weights = []\n",
    "\n",
    "    # Computes the L2 norm of the product of all the matrices\n",
    "    Ai = A[i]\n",
    "    Bi = B_is[i]\n",
    "    Ai_rs = Ai.reshape(-1, 1)\n",
    "    Bi_rs = Bi.reshape(-1, 1)\n",
    "    corpus_i = np.dot(np.dot(Ai_rs, Bi_rs.T).T, C.T)\n",
    "    corpus_norm_i = l2_norm(corpus_i)\n",
    "    \n",
    "    # Computes the weight of each topic in corpus as norm of the product of the component matrices divided by corpus-wide norm\n",
    "    for t in range(0, n_components):\n",
    "        Ait = A[i][t]\n",
    "        Bit = B_is[i][:, t]\n",
    "        Ct = C[:,t]\n",
    "    \n",
    "        Ait_rs = Ait.reshape(-1, 1)\n",
    "        Bit_rs = Bit.reshape(-1, 1)\n",
    "        Ct_rs = Ct.reshape(-1, 1)\n",
    "    \n",
    "        topic_component_it = np.dot(np.dot(Ait_rs, Bit_rs.T).T, Ct_rs.T)\n",
    "        topic_norm_it = l2_norm(topic_component_it)\n",
    "        topic_weight_it = topic_norm_it/corpus_norm_i\n",
    "        \n",
    "        topic_weights.append(topic_weight_it)\n",
    "        \n",
    "    return topic_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c10303f-a631-4479-8d1f-f1d7aa12b482",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_weights_0 = norm_topics_in_corpus(0, n_components, A, B_is, C)\n",
    "topic_weights_1 = norm_topics_in_corpus(1, n_components, A, B_is, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e08db27-f2e3-4e03-bcdd-e8d1e97bcb10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_parl = pd.DataFrame(data=[topic_weights_0], index=['Parliament Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_parl = df_topics_parl.iloc[:, df_topics_parl.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_parl.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Parliament Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_parl.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d9bb51-5c22-438e-af13-19ae235867bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame with one row containing the information from A[0]\n",
    "df_topics_media = pd.DataFrame(data=[topic_weights_1], index=['Media Topic Weights'], columns=top_3_words)\n",
    "\n",
    "# Sort the topics based on their weights\n",
    "sorted_topics_df_media = df_topics_media.iloc[:, df_topics_media.iloc[0].argsort()[::-1]]\n",
    "\n",
    "# Display the sorted horizontal histogram\n",
    "plt.figure(figsize=(2, 6))\n",
    "plt.barh(range(1, n_components + 1), sorted_topics_df_media.values[0][::-1], color='dodgerblue')\n",
    "plt.ylabel('Topic')\n",
    "plt.xlabel('Weight')\n",
    "plt.title('Topics in Media Corpus')\n",
    "plt.yticks(range(1, n_components + 1), sorted_topics_df_media.columns[::-1])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5e4d45-ba69-4164-a258-27054f9df519",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a figure with two subplots\n",
    "fig, axs = plt.subplots(1, 2, figsize=(6, 8), sharey=True)\n",
    "\n",
    "# Plot Parliament Corpus Topics\n",
    "axs[0].barh(range(1, n_components + 1), df_topics_parl.values[0][::-1], color='dodgerblue')\n",
    "axs[0].set_ylabel('Topic')\n",
    "axs[0].set_xlabel('Weight')\n",
    "axs[0].set_title('Parliament')\n",
    "axs[0].set_yticks(range(1, n_components + 1))\n",
    "axs[0].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "# Plot Media Corpus Topics\n",
    "axs[1].barh(range(1, n_components + 1), df_topics_media.values[0][::-1], color='dodgerblue')\n",
    "axs[1].set_xlabel('Weight')\n",
    "axs[1].set_title('Media')\n",
    "axs[1].set_yticks(range(1, n_components + 1))\n",
    "axs[1].set_yticklabels(top_3_words[::-1])  # Use the same y-axis labels as the sorted one\n",
    "\n",
    "fig.suptitle('Process', fontsize=16)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5dfde49-368e-41f9-ab9a-9687071f1735",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate parliament and media DataFrames\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "plt.title(f'Process', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17826250-a439-4165-be76-12f9b693d22d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate parliament and media DataFrames\n",
    "df_corporacomparison = pd.concat([df_topics_parl, df_topics_media], axis=0)\n",
    "\n",
    "# Transpose the DataFrame to have topics as rows and corpora as columns\n",
    "df_corpcomp_T = df_corporacomparison.transpose()\n",
    "\n",
    "# Calculate the difference between Media and Parliament for each topic\n",
    "df_corpcomp_T['Difference'] = (df_corpcomp_T['Media Topic Weights'] - df_corpcomp_T['Parliament Topic Weights'])*100/(df_corpcomp_T['Media Topic Weights'].max() + df_corpcomp_T['Parliament Topic Weights'].max())\n",
    "\n",
    "# Sort the DataFrame based on the 'Difference' column\n",
    "df_corpcomp_T_sorted = df_corpcomp_T.sort_values(by='Difference', ascending=True)\n",
    "\n",
    "# Plot the data\n",
    "fig = plt.figure(figsize=(5, 5)) \n",
    "ax = fig.add_subplot(111)\n",
    "\n",
    "# Normalize the 'Difference' values for color assignment\n",
    "norm = plt.Normalize(vmin=df_corpcomp_T_sorted['Difference'].min(), vmax=df_corpcomp_T_sorted['Difference'].max())\n",
    "\n",
    "# Plot the bars with color based on the 'Difference' values using the 'RdYlBu' colormap\n",
    "bars = ax.barh(df_corpcomp_T_sorted.index, df_corpcomp_T_sorted['Difference'], color=plt.cm.PiYG(norm(df_corpcomp_T_sorted['Difference'])))\n",
    "\n",
    "# Set the labels and title\n",
    "plt.xlabel('Topic Coverage\\nMismatch (%)')\n",
    "#plt.ylabel('Topics')\n",
    "#plt.title(f'Process', loc='center')\n",
    "\n",
    "# Add light grey grid\n",
    "ax.grid(axis='x', linestyle=':', alpha=0.5, color='lightgrey', zorder=0)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#fig.set_size_inches(8, 10, forward=True)\n",
    "#fig.savefig('/home/sparazzoli/lagrange-oecd/code/corpus_comparison/TopicCoverageDifference_February.png', dpi=300, bbox_inches='tight') # Set desired DPI value and pass 'bbox_inches' argument to remove white spaces around the edges"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
